trivial-factor-weight.h
Go to the documentation of this file.
1 // fstext/trivial-factor-weight.h
2 
3 // Copyright 2009-2011 Microsoft Corporation
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 //
20 //
21 // This is a modified file from the OpenFST Library v1.2.7 available at
22 // http://www.openfst.org and released under the Apache License Version 2.0.
23 //
24 //
25 // See ../../COPYING for clarification regarding multiple authors
26 //
27 // Licensed under the Apache License, Version 2.0 (the "License");
28 // you may not use this file except in compliance with the License.
29 // You may obtain a copy of the License at
30 //
31 // http://www.apache.org/licenses/LICENSE-2.0
32 //
33 // Unless required by applicable law or agreed to in writing, software
34 // distributed under the License is distributed on an "AS IS" BASIS,
35 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
36 // See the License for the specific language governing permissions and
37 // limitations under the License.
38 //
39 // Copyright 2005-2010 Google, Inc.
40 // Author: allauzen@google.com (Cyril Allauzen)
41 
42 
43 #ifndef KALDI_FSTEXT_TRIVIAL_FACTOR_WEIGHT_H_
44 #define KALDI_FSTEXT_TRIVIAL_FACTOR_WEIGHT_H_
45 
46 
47 // trivial-factor-weight.h: this is an extension to factor-weight.h in the
48 // OpenFst code. It is a version of FactorWeight that creates separate states
49 // (with input epsilons) rather than pushing the factors forward. This is for
50 // converting from Gallic FSTs, where you want the result to be a bit more
51 // trivial, with input epsilons inserted where there are multiple output
52 // symbols. This has the advantage that it always works, for any input (also I
53 // just prefer this approach).
54 
55 #include <unordered_map>
56 using std::unordered_map;
57 
58 #include <algorithm>
59 #include <string>
60 #include <utility>
61 #include <vector>
62 
63 #include <fst/cache.h>
64 #include <fst/test-properties.h>
65 
66 namespace fst {
67 
68 
69 template <class Arc>
71  typedef typename Arc::Label Label;
72  float delta;
73  Label extra_ilabel; // input label of extra arcs
74  Label extra_olabel; // output label of extra arcs
75 
77  Label il = 0, Label ol = 0)
78  : CacheOptions(opts), delta(d), extra_ilabel(il), extra_olabel(ol) {}
79 
81  float d, Label il = 0, Label ol = 0)
82  : delta(d), extra_ilabel(il), extra_olabel(ol) {}
83 
84  TrivialFactorWeightOptions(): delta(kDelta), extra_ilabel(0), extra_olabel(0) {}
85 
86 };
87 
88 namespace internal {
89 
90 // Implementation class for TrivialFactorWeight
91 template <class A, class F>
93  : public CacheImpl<A> {
94  public:
96  using FstImpl<A>::SetType;
97  using FstImpl<A>::SetProperties;
98  using FstImpl<A>::Properties;
99  using FstImpl<A>::SetInputSymbols;
100  using FstImpl<A>::SetOutputSymbols;
101 
102  using CacheBaseImpl< CacheState<A> >::HasStart;
103  using CacheBaseImpl< CacheState<A> >::HasFinal;
104  using CacheBaseImpl< CacheState<A> >::HasArcs;
105 
106  typedef A Arc;
107  typedef typename A::Label Label;
108  typedef typename A::Weight Weight;
109  typedef typename A::StateId StateId;
110  typedef F FactorIterator;
111 
112  typedef DefaultCacheStore<A> Store;
113  typedef typename Store::State State;
114 
// Identity of one state of the output FST: a state of the input FST paired
// with a weight that still has to be emitted. Final() and Expand() treat
// weight == Weight::One() as a "real" input state, and state == kNoStateId
// as an extra state created to carry a factored final weight.
115  struct Element {
116  Element() {} // Does not explicitly initialize `state` or `weight`.
117 
118  Element(StateId s, Weight w) : state(s), weight(w) {}
119 
120  StateId state; // Input state Id (kNoStateId for extra final-weight states)
121  Weight weight; // Residual weight
122  };
123 
125  : CacheImpl<A>(opts),
126  fst_(fst.Copy()),
127  delta_(opts.delta),
128  extra_ilabel_(opts.extra_ilabel),
129  extra_olabel_(opts.extra_olabel) {
130  SetType("factor-weight");
131  uint64 props = fst.Properties(kFstProperties, false);
132  SetProperties(FactorWeightProperties(props), kCopyProperties);
133 
134  SetInputSymbols(fst.InputSymbols());
135  SetOutputSymbols(fst.OutputSymbols());
136  }
137 
139  : CacheImpl<A>(impl),
140  fst_(impl.fst_->Copy(true)),
141  delta_(impl.delta_),
142  extra_ilabel_(impl.extra_ilabel_),
143  extra_olabel_(impl.extra_olabel_) {
144  SetType("factor-weight");
145  SetProperties(impl.Properties(), kCopyProperties);
146  SetInputSymbols(impl.InputSymbols());
147  SetOutputSymbols(impl.OutputSymbols());
148  }
149 
// Returns the start state of the output FST.
// On the first call (HasStart() false) the start state is looked up or
// created as the Element (input start state, Weight::One()) and recorded with
// SetStart(); the value is then read back through CacheImpl<A>::Start().
// If the input FST has no start state, kNoStateId is returned and nothing is
// recorded, so the check is repeated on every call.
150  StateId Start() {
151  if (!HasStart()) {
152  StateId s = fst_->Start();
153  if (s == kNoStateId)
154  return kNoStateId;
155  StateId start = this->FindState(Element(fst_->Start(), Weight::One()));
156  this->SetStart(start);
157  }
158  return CacheImpl<A>::Start();
159  }
160 
// Returns the final weight of output state s.
// The weight is computed on first request, stored with SetFinal(), and read
// from CacheImpl<A>::Final(s) on later requests.
// A state is final (non-Zero weight) only when its weight cannot be factored
// any further by FactorIterator:
//  - extra final-weight state (e.state == kNoStateId): final with e.weight if
//    e.weight is unfactorable, otherwise Zero;
//  - intermediate state (e.weight != One): always Zero;
//  - real state (e.weight == One): the input state's final weight if that is
//    unfactorable, otherwise Zero (Expand() adds an extra arc instead).
// Note: s is used to index elements_ without a bounds check here.
161  Weight Final(StateId s) {
162  if (!HasFinal(s)) {
163  const Element &e = elements_[s];
164  Weight w;
165  if (e.state == kNoStateId) { // extra state inserted to represent final weights.
166  FactorIterator fit(e.weight);
167  if (fit.Done()) { // cannot be factored.
168  w = e.weight; // so it's final
169  } else {
170  w = Weight::Zero(); // need another transition.
171  }
172  } else {
173  if (e.weight != Weight::One()) { // Not a real state.
174  w = Weight::Zero();
175  } else { // corresponds to a "real" state.
176  w = fst_->Final(e.state);
177  FactorIterator fit(w);
178  if (!fit.Done()) // we would have intermediate states representing this final state.
179  w = Weight::Zero();
180  }
181  }
182  this->SetFinal(s, w);
183  return w;
184  } else {
185  return CacheImpl<A>::Final(s);
186  }
187  }
188 
// Returns the number of arcs leaving output state s, calling Expand(s) first
// if the arcs of s have not been computed yet.
189  size_t NumArcs(StateId s) {
190  if (!HasArcs(s))
191  Expand(s);
192  return CacheImpl<A>::NumArcs(s);
193  }
194 
195  size_t NumInputEpsilons(StateId s) {
196  if (!HasArcs(s))
197  Expand(s);
199  }
200 
201  size_t NumOutputEpsilons(StateId s) {
202  if (!HasArcs(s))
203  Expand(s);
205  }
206 
207  void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
208  if (!HasArcs(s))
209  Expand(s);
211  }
212 
213 
214  // Returns the output state id corresponding to Element e, creating a new
215  // state if e has not been seen before.
// A new state gets the next free id (the current size of elements_); e is
// appended to elements_ and registered in element_map_. Because this may grow
// elements_, references into elements_ held by the caller can be invalidated.
216  StateId FindState(const Element &e) {
217  typename ElementMap::iterator eit = element_map_.find(e);
218  if (eit != element_map_.end()) {
219  return (*eit).second;
220  } else {
221  StateId s = elements_.size();
222  elements_.push_back(e);
223  element_map_.insert(std::pair<const Element, StateId>(e, s));
224  return s;
225  }
226  }
227 
228  // Computes the outgoing transitions from output state s, creating new
229  // destination states as needed, then marks the arcs of s as complete with
// SetArcs(s). s must already be a valid index into elements_ (CHECKed).
//
// Three kinds of state are handled:
//  1. Non-unit, unfactorable weight: if it belongs to an input state, one
//     extra arc carrying the whole weight leads to that real state; if it is
//     an extra final-weight state (kNoStateId) no arc is added.
//  2. Non-unit, factorable weight: one extra arc carries the first factor and
//     leads to the state holding the quantized remainder.
//  3. Unit weight ("real" state): every input arc is copied; a factorable arc
//     weight is replaced by its first factor, with the quantized remainder
//     stored in the destination Element. A factorable final weight gets an
//     extra arc to a kNoStateId state.
// Extra arcs use extra_ilabel_ / extra_olabel_ as their labels.
230  void Expand(StateId s) {
231  CHECK(static_cast<size_t>(s) < elements_.size());
232  Element e = elements_[s]; // Copy, not reference: FindState() below may grow elements_.
233  if (e.weight != Weight::One()) {
234  FactorIterator fit(e.weight);
235  if (fit.Done()) { // Cannot be factored-> create a link to dest state directly
236  if (e.state != kNoStateId) {
237  StateId dest = FindState(Element(e.state, Weight::One()));
238  PushArc(s, Arc(extra_ilabel_, extra_olabel_, e.weight, dest));
239  } // else we're done. This is a final state.
240  } else { // Can be factored.
241  const std::pair<Weight, Weight> &p = fit.Value();
242  StateId dest = FindState(Element(e.state, p.second.Quantize(delta_)));
243  PushArc(s, Arc(extra_ilabel_, extra_olabel_, p.first, dest));
244  }
245  } else { // Unit weight. This corresponds to a "real" state.
246  CHECK(e.state != kNoStateId);
247  for (ArcIterator< Fst<A> > ait(*fst_, e.state);
248  !ait.Done();
249  ait.Next()) {
250  const A &arc = ait.Value();
251  FactorIterator fit(arc.weight);
252  if (fit.Done()) { // cannot be factored->just link directly to dest.
253  StateId dest = FindState(Element(arc.nextstate, Weight::One()));
254  PushArc(s, Arc(arc.ilabel, arc.olabel, arc.weight, dest));
255  } else { // emit the first factor here; the remainder is carried by the destination state.
256  const std::pair<Weight, Weight> &p = fit.Value();
257  StateId dest = FindState(Element(arc.nextstate, p.second.Quantize(delta_)));
258  PushArc(s, Arc(arc.ilabel, arc.olabel, p.first, dest));
259  }
260  }
261  // See if we have to add arcs for final-states [only if final-weight is factorable].
262  Weight final_w = fst_->Final(e.state);
263  if (final_w != Weight::Zero()) {
264  FactorIterator fit(final_w);
265  if (!fit.Done()) {
266  const std::pair<Weight, Weight> &p = fit.Value();
267  StateId dest = FindState(Element(kNoStateId, p.second.Quantize(delta_)));
268  PushArc(s, Arc(extra_ilabel_, extra_olabel_, p.first, dest));
269  }
270  }
271  }
272  this->SetArcs(s);
273  }
274 
275  private:
276  // Equality predicate for Elements, used as the key-equality functor of
// ElementMap. Two Elements are equal when both the input state and the
// weight compare equal with operator==; weights are assumed to have been
// quantized already, so no delta is applied here.
277  class ElementEqual {
278  public:
279  bool operator()(const Element &x, const Element &y) const {
280  return x.state == y.state && x.weight == y.weight;
281  }
282  };
283 
284  // Hash function for Elements, used as the hasher of ElementMap.
// Combines the input state id (scaled by a prime) with the weight's own hash.
285  class ElementKey {
286  public:
287  size_t operator()(const Element &x) const {
// Convert the state id to size_t *before* multiplying, so the product is
// computed in unsigned arithmetic and wraps instead of overflowing a
// signed int (undefined behaviour) for large state ids. The result is the
// same as before for all ids that did not overflow, including kNoStateId.
288  return static_cast<size_t>(x.state) * kPrime + x.weight.Hash();
289  }
290  private:
291  static const int kPrime = 7853;
292  };
293 
// Hash map from Element to output state id; the inverse of elements_.
294  typedef unordered_map<Element, StateId, ElementKey, ElementEqual> ElementMap;
295 
296  std::unique_ptr<const Fst<A>> fst_; // private copy of the input FST
297  float delta_; // quantization delta applied to residual weights before state lookup
// NOTE(review): mode_ is not initialized by either constructor and is not read
// anywhere in the code visible here; it looks like a leftover from
// FactorWeightFst -- confirm before relying on it.
298  uint32 mode_; // factoring arc and/or final weights
299  Label extra_ilabel_; // ilabel of the extra arcs that emit residual or final-weight factors
300  Label extra_olabel_; // olabel of the extra arcs that emit residual or final-weight factors
301  std::vector<Element> elements_; // mapping Fst state to Elements
302  ElementMap element_map_; // mapping Elements to Fst state
303 
304 };
305 
306 } // namespace internal
307 
319 
322 
323 
324 template <class A, class F>
326  public ImplToFst<internal::TrivialFactorWeightFstImpl<A, F>> {
327  public:
328  friend class ArcIterator< TrivialFactorWeightFst<A, F> >;
329  friend class StateIterator< TrivialFactorWeightFst<A, F> >;
330 
331  typedef A Arc;
332  typedef typename A::Weight Weight;
333  typedef typename A::StateId StateId;
334  typedef DefaultCacheStore<Arc> Store;
335  typedef typename Store::State State;
337 
// Wraps `fst` using default-constructed TrivialFactorWeightOptions<A>
// (delta = kDelta, extra input/output labels = 0). The implementation object
// is created with std::make_shared and held by the ImplToFst base.
338  explicit TrivialFactorWeightFst(const Fst<A> &fst)
339  : ImplToFst<Impl>(std::make_shared<Impl>(fst, TrivialFactorWeightOptions<A>())) {}
340 
342  : ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}
343 
344  // See Fst<>::Copy() for doc.
346  : ImplToFst<Impl>(fst, copy) {}
347 
348  // Get a copy of this TrivialFactorWeightFst. See Fst<>::Copy() for further doc.
// The object is allocated with `new` and returned as a raw pointer
// (presumably owned by the caller, as with other Fst<>::Copy() overrides).
// `copy` is forwarded unchanged to the (fst, copy) constructor.
349  TrivialFactorWeightFst<A, F> *Copy(bool copy = false) const override {
350  return new TrivialFactorWeightFst<A, F>(*this, copy);
351  }
352 
353  inline void InitStateIterator(StateIteratorData<A> *data) const override;
354 
// Initializes `data` for iterating over the arcs of state s by forwarding to
// the implementation. The mutable implementation is used even though this
// method is const, because the implementation may have to expand s first.
355  void InitArcIterator(StateId s, ArcIteratorData<A> *data) const override {
356  GetMutableImpl()->InitArcIterator(s, data);
357  }
358 
359  private:
362 
363  TrivialFactorWeightFst &operator=(const TrivialFactorWeightFst &fst) = delete;
364 };
365 
366 
367 // Specialization for TrivialFactorWeightFst.
368 template<class A, class F>
369 class StateIterator< TrivialFactorWeightFst<A, F> >
370  : public CacheStateIterator< TrivialFactorWeightFst<A, F> > {
371  public:
373  : CacheStateIterator< TrivialFactorWeightFst<A, F> >(fst, fst.GetMutableImpl()) {}
374 };
375 
376 
377 // Specialization for TrivialFactorWeightFst.
378 template <class A, class F>
379 class ArcIterator< TrivialFactorWeightFst<A, F> >
380  : public CacheArcIterator< TrivialFactorWeightFst<A, F> > {
381  public:
382  typedef typename A::StateId StateId;
383 
385  : CacheArcIterator< TrivialFactorWeightFst<A, F>>(fst.GetMutableImpl(), s) {
386  if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
387  }
388 };
389 
390 template <class A, class F>
392  StateIteratorData<A> *data) const {
393  data->base = new StateIterator< TrivialFactorWeightFst<A, F> >(*this);
394 }
395 
396 
397 
398 
399 } // namespace fst
400 
401 #endif
fst::StdArc::StateId StateId
StateIterator(const TrivialFactorWeightFst< A, F > &fst)
TrivialFactorWeightOptions(float d, Label il=0, Label ol=0)
TrivialFactorWeightOptions(const CacheOptions &opts, float d, Label il=0, Label ol=0)
TrivialFactorWeightFstImpl(const TrivialFactorWeightFstImpl< A, F > &impl)
TrivialFactorWeightFst takes as template parameter a FactorIterator as defined above.
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
void InitArcIterator(StateId s, ArcIteratorData< A > *data) const override
TrivialFactorWeightFstImpl(const Fst< A > &fst, const TrivialFactorWeightOptions< A > &opts)
bool operator()(const Element &x, const Element &y) const
void InitStateIterator(StateIteratorData< A > *data) const override
TrivialFactorWeightFst(const Fst< A > &fst)
void InitArcIterator(StateId s, ArcIteratorData< A > *data)
TrivialFactorWeightFst(const Fst< A > &fst, const TrivialFactorWeightOptions< A > &opts)
unordered_map< Element, StateId, ElementKey, ElementEqual > ElementMap
std::unique_ptr< const Fst< A > > fst_
ArcIterator(const TrivialFactorWeightFst< A, F > &fst, StateId s)
TrivialFactorWeightFst(const TrivialFactorWeightFst< A, F > &fst, bool copy)
fst::StdArc::Label Label
DefaultCacheStore< Arc > Store
fst::StdArc::Weight Weight
internal::TrivialFactorWeightFstImpl< A, F > Impl
TrivialFactorWeightFst< A, F > * Copy(bool copy=false) const override
Arc::StateId NumArcs(const ExpandedFst< Arc > &fst)
Returns the total number of arcs in an FST.
void Copy(const CuMatrixBase< Real > &src, const CuArray< int32 > &copy_from_indices, CuMatrixBase< Real > *tgt)
Copies elements from src into tgt as given by copy_from_indices.
Definition: cu-math.cc:173