fstext-utils-inl.h
Go to the documentation of this file.
1 // fstext/fstext-utils-inl.h
2 
3 // Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey)
4 // 2014 Telepoint Global Hosting Service, LLC. (Author: David Snyder)
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #ifndef KALDI_FSTEXT_FSTEXT_UTILS_INL_H_
22 #define KALDI_FSTEXT_FSTEXT_UTILS_INL_H_
23 #include <cstring>
24 #include "base/kaldi-common.h"
25 #include "util/stl-utils.h"
26 #include "util/text-utils.h"
27 #include "util/kaldi-io.h"
28 #include "fstext/factor.h"
29 #include "fstext/pre-determinize.h"
31 
32 #include <sstream>
33 #include <algorithm>
34 #include <string>
35 
36 namespace fst {
37 
38 
39 
40 template<class Arc>
41 typename Arc::Label HighestNumberedOutputSymbol(const Fst<Arc> &fst) {
42  typename Arc::Label ans = 0;
43  for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
44  typename Arc::StateId s = siter.Value();
45  for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
46  const Arc &arc = aiter.Value();
47  ans = std::max(ans, arc.olabel);
48  }
49  }
50  return ans;
51 }
52 
53 template<class Arc>
54 typename Arc::Label HighestNumberedInputSymbol(const Fst<Arc> &fst) {
55  typename Arc::Label ans = 0;
56  for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
57  typename Arc::StateId s = siter.Value();
58  for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
59  const Arc &arc = aiter.Value();
60  ans = std::max(ans, arc.ilabel);
61  }
62  }
63  return ans;
64 }
65 
66 template<class Arc>
67 typename Arc::StateId NumArcs(const ExpandedFst<Arc> &fst) {
68  typedef typename Arc::StateId StateId;
69  StateId num_arcs = 0;
70  for (StateId s = 0; s < fst.NumStates(); s++)
71  num_arcs += fst.NumArcs(s);
72  return num_arcs;
73 }
74 
75 template<class Arc, class I>
76 void GetOutputSymbols(const Fst<Arc> &fst,
77  bool include_eps,
78  std::vector<I> *symbols) {
80  std::set<I> all_syms;
81  for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
82  typename Arc::StateId s = siter.Value();
83  for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
84  const Arc &arc = aiter.Value();
85  all_syms.insert(arc.olabel);
86  }
87  }
88 
89  // Remove epsilon, if instructed.
90  if (!include_eps && !all_syms.empty() && *all_syms.begin() == 0)
91  all_syms.erase(0);
92  KALDI_ASSERT(symbols != NULL);
93  kaldi::CopySetToVector(all_syms, symbols);
94 }
95 
96 template<class Arc, class I>
97 void GetInputSymbols(const Fst<Arc> &fst,
98  bool include_eps,
99  std::vector<I> *symbols) {
101  unordered_set<I> all_syms;
102  for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
103  typename Arc::StateId s = siter.Value();
104  for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
105  const Arc &arc = aiter.Value();
106  all_syms.insert(arc.ilabel);
107  }
108  }
109  // Remove epsilon, if instructed.
110  if (!include_eps && all_syms.count(0) != 0)
111  all_syms.erase(0);
112  KALDI_ASSERT(symbols != NULL);
113  kaldi::CopySetToVector(all_syms, symbols);
114  std::sort(symbols->begin(), symbols->end());
115 }
116 
117 
118 template<class Arc, class I>
119 void RemoveSomeInputSymbols(const std::vector<I> &to_remove,
120  MutableFst<Arc> *fst) {
122  RemoveSomeInputSymbolsMapper<Arc, I> mapper(to_remove);
123  Map(fst, mapper);
124 }
125 
126 template<class Arc, class I>
128  public:
129  Arc operator ()(const Arc &arc_in) {
130  Arc ans = arc_in;
131  if (ans.ilabel > 0 &&
132  ans.ilabel < static_cast<typename Arc::Label>((*symbol_mapping_).size()))
133  ans.ilabel = (*symbol_mapping_)[ans.ilabel];
134  return ans;
135  }
136  MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
137  MapSymbolsAction InputSymbolsAction() const { return MAP_CLEAR_SYMBOLS; }
138  MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
139  uint64 Properties(uint64 props) const { // Not tested.
140  bool remove_epsilons = (symbol_mapping_->size() > 0 && (*symbol_mapping_)[0] != 0);
141  bool add_epsilons = (symbol_mapping_->size() > 1 &&
142  *std::min_element(symbol_mapping_->begin()+1, symbol_mapping_->end()) == 0);
143 
144  // remove the following as we don't know now if any of them are true.
145  uint64 props_to_remove = kAcceptor|kNotAcceptor|kIDeterministic|kNonIDeterministic|
146  kILabelSorted|kNotILabelSorted;
147  if (remove_epsilons) props_to_remove |= kEpsilons|kIEpsilons;
148  if (add_epsilons) props_to_remove |= kNoEpsilons|kNoIEpsilons;
149  uint64 props_to_add = 0;
150  if (remove_epsilons && !add_epsilons) props_to_add |= kNoEpsilons|kNoIEpsilons;
151  return (props & ~props_to_remove) | props_to_add;
152  }
153  // initialize with copy = false only if the "to_remove" argument will not be deleted
154  // in the lifetime of this object.
155  MapInputSymbolsMapper(const std::vector<I> &to_remove, bool copy) {
157  if (copy) symbol_mapping_ = new std::vector<I> (to_remove);
158  else symbol_mapping_ = &to_remove;
159  owned = copy;
160  }
162  private:
163  bool owned;
164  const std::vector<I> *symbol_mapping_;
165 };
166 
167 template<class Arc, class I>
168 void MapInputSymbols(const std::vector<I> &symbol_mapping,
169  MutableFst<Arc> *fst) {
171  // false == don't copy the "symbol_mapping", retain pointer--
172  // safe since short-lived object.
173  MapInputSymbolsMapper<Arc, I> mapper(symbol_mapping, false);
174  Map(fst, mapper);
175 }
176 
177 template<class Arc, class I>
178 bool GetLinearSymbolSequence(const Fst<Arc> &fst,
179  std::vector<I> *isymbols_out,
180  std::vector<I> *osymbols_out,
181  typename Arc::Weight *tot_weight_out) {
182  typedef typename Arc::StateId StateId;
183  typedef typename Arc::Weight Weight;
184 
185  Weight tot_weight = Weight::One();
186  std::vector<I> ilabel_seq;
187  std::vector<I> olabel_seq;
188 
189  StateId cur_state = fst.Start();
190  if (cur_state == kNoStateId) { // empty sequence.
191  if (isymbols_out != NULL) isymbols_out->clear();
192  if (osymbols_out != NULL) osymbols_out->clear();
193  if (tot_weight_out != NULL) *tot_weight_out = Weight::Zero();
194  return true;
195  }
196  while (1) {
197  Weight w = fst.Final(cur_state);
198  if (w != Weight::Zero()) { // is final..
199  tot_weight = Times(w, tot_weight);
200  if (fst.NumArcs(cur_state) != 0) return false;
201  if (isymbols_out != NULL) *isymbols_out = ilabel_seq;
202  if (osymbols_out != NULL) *osymbols_out = olabel_seq;
203  if (tot_weight_out != NULL) *tot_weight_out = tot_weight;
204  return true;
205  } else {
206  if (fst.NumArcs(cur_state) != 1) return false;
207 
208  ArcIterator<Fst<Arc> > iter(fst, cur_state); // get the only arc.
209  const Arc &arc = iter.Value();
210  tot_weight = Times(arc.weight, tot_weight);
211  if (arc.ilabel != 0) ilabel_seq.push_back(arc.ilabel);
212  if (arc.olabel != 0) olabel_seq.push_back(arc.olabel);
213  cur_state = arc.nextstate;
214  }
215  }
216 }
217 
218 
219 // see fstext-utils.h for comment.
220 template<class Arc>
221 void ConvertNbestToVector(const Fst<Arc> &fst,
222  std::vector<VectorFst<Arc> > *fsts_out) {
223  typedef typename Arc::Weight Weight;
224  typedef typename Arc::StateId StateId;
225  fsts_out->clear();
226  StateId start_state = fst.Start();
227  if (start_state == kNoStateId) return; // No output.
228  size_t n_arcs = fst.NumArcs(start_state);
229  bool start_is_final = (fst.Final(start_state) != Weight::Zero());
230  fsts_out->reserve(n_arcs + (start_is_final ? 1 : 0));
231 
232  if (start_is_final) {
233  fsts_out->resize(fsts_out->size() + 1);
234  StateId start_state_out = fsts_out->back().AddState();
235  fsts_out->back().SetFinal(start_state_out, fst.Final(start_state));
236  }
237 
238  for (ArcIterator<Fst<Arc> > start_aiter(fst, start_state);
239  !start_aiter.Done();
240  start_aiter.Next()) {
241  fsts_out->resize(fsts_out->size() + 1);
242  VectorFst<Arc> &ofst = fsts_out->back();
243  const Arc &first_arc = start_aiter.Value();
244  StateId cur_state = start_state,
245  cur_ostate = ofst.AddState();
246  ofst.SetStart(cur_ostate);
247  StateId next_ostate = ofst.AddState();
248  ofst.AddArc(cur_ostate, Arc(first_arc.ilabel, first_arc.olabel,
249  first_arc.weight, next_ostate));
250  cur_state = first_arc.nextstate;
251  cur_ostate = next_ostate;
252  while (1) {
253  size_t this_n_arcs = fst.NumArcs(cur_state);
254  KALDI_ASSERT(this_n_arcs <= 1); // or it violates our assumptions
255  // about the input.
256  if (this_n_arcs == 1) {
257  KALDI_ASSERT(fst.Final(cur_state) == Weight::Zero());
258  // or problem with ShortestPath.
259  ArcIterator<Fst<Arc> > aiter(fst, cur_state);
260  const Arc &arc = aiter.Value();
261  next_ostate = ofst.AddState();
262  ofst.AddArc(cur_ostate, Arc(arc.ilabel, arc.olabel,
263  arc.weight, next_ostate));
264  cur_state = arc.nextstate;
265  cur_ostate = next_ostate;
266  } else {
267  KALDI_ASSERT(fst.Final(cur_state) != Weight::Zero());
268  // or problem with ShortestPath.
269  ofst.SetFinal(cur_ostate, fst.Final(cur_state));
270  break;
271  }
272  }
273  }
274 }
275 
276 
277 // see fstext-utils.sh for comment.
278 template<class Arc>
279 void NbestAsFsts(const Fst<Arc> &fst,
280  size_t n,
281  std::vector<VectorFst<Arc> > *fsts_out) {
282  KALDI_ASSERT(n > 0);
283  KALDI_ASSERT(fsts_out != NULL);
284  VectorFst<Arc> nbest_fst;
285  ShortestPath(fst, &nbest_fst, n);
286  ConvertNbestToVector(nbest_fst, fsts_out);
287 }
288 
289 template<class Arc, class I>
290 void MakeLinearAcceptorWithAlternatives(const std::vector<std::vector<I> > &labels,
291  MutableFst<Arc> *ofst) {
292  typedef typename Arc::StateId StateId;
293  typedef typename Arc::Weight Weight;
294 
295  ofst->DeleteStates();
296  StateId cur_state = ofst->AddState();
297  ofst->SetStart(cur_state);
298  for (size_t i = 0; i < labels.size(); i++) {
299  KALDI_ASSERT(labels[i].size() != 0);
300  StateId next_state = ofst->AddState();
301  for (size_t j = 0; j < labels[i].size(); j++) {
302  Arc arc(labels[i][j], labels[i][j], Weight::One(), next_state);
303  ofst->AddArc(cur_state, arc);
304  }
305  cur_state = next_state;
306  }
307  ofst->SetFinal(cur_state, Weight::One());
308 }
309 
310 template<class Arc, class I>
311 void MakeLinearAcceptor(const std::vector<I> &labels, MutableFst<Arc> *ofst) {
312  typedef typename Arc::StateId StateId;
313  typedef typename Arc::Weight Weight;
314 
315  ofst->DeleteStates();
316  StateId cur_state = ofst->AddState();
317  ofst->SetStart(cur_state);
318  for (size_t i = 0; i < labels.size(); i++) {
319  StateId next_state = ofst->AddState();
320  Arc arc(labels[i], labels[i], Weight::One(), next_state);
321  ofst->AddArc(cur_state, arc);
322  cur_state = next_state;
323  }
324  ofst->SetFinal(cur_state, Weight::One());
325 }
326 
327 
328 template<class I>
329 void GetSymbols(const SymbolTable &symtab,
330  bool include_eps,
331  std::vector<I> *syms_out) {
332  KALDI_ASSERT(syms_out != NULL);
333  syms_out->clear();
334  for (SymbolTableIterator iter(symtab);
335  !iter.Done();
336  iter.Next()) {
337  if (include_eps || iter.Value() != 0) {
338  syms_out->push_back(iter.Value());
339  KALDI_ASSERT(syms_out->back() == iter.Value()); // an integer-range thing.
340  }
341  }
342 }
343 
344 template<class Arc>
345 void SafeDeterminizeWrapper(MutableFst<Arc> *ifst, MutableFst<Arc> *ofst, float delta) {
346  typename Arc::Label highest_sym = HighestNumberedInputSymbol(*ifst);
347  std::vector<typename Arc::Label> extra_syms;
348  PreDeterminize(ifst,
349  (typename Arc::Label)(highest_sym+1),
350  &extra_syms);
351  DeterminizeStar(*ifst, ofst, delta);
352  RemoveSomeInputSymbols(extra_syms, ofst); // remove the extra symbols.
353 }
354 
355 
356 template<class Arc>
357 void SafeDeterminizeMinimizeWrapper(MutableFst<Arc> *ifst, VectorFst<Arc> *ofst, float delta) {
358  typename Arc::Label highest_sym = HighestNumberedInputSymbol(*ifst);
359  std::vector<typename Arc::Label> extra_syms;
360  PreDeterminize(ifst,
361  (typename Arc::Label)(highest_sym+1),
362  &extra_syms);
363  DeterminizeStar(*ifst, ofst, delta);
364  RemoveSomeInputSymbols(extra_syms, ofst); // remove the extra symbols.
365  RemoveEpsLocal(ofst); // this is "safe" and will never hurt.
366  MinimizeEncoded(ofst, delta);
367 }
368 
369 
370 inline
371 void DeterminizeStarInLog(VectorFst<StdArc> *fst, float delta, bool *debug_ptr, int max_states) {
372  // DeterminizeStarInLog determinizes 'fst' in the log semiring, using
373  // the DeterminizeStar algorithm (which also removes epsilons).
374 
375  ArcSort(fst, ILabelCompare<StdArc>()); // helps DeterminizeStar to be faster.
376  VectorFst<LogArc> *fst_log = new VectorFst<LogArc>; // Want to determinize in log semiring.
377  Cast(*fst, fst_log);
378  VectorFst<StdArc> tmp;
379  *fst = tmp; // make fst empty to free up memory. [actually may make no difference..]
380  VectorFst<LogArc> *fst_det_log = new VectorFst<LogArc>;
381  DeterminizeStar(*fst_log, fst_det_log, delta, debug_ptr, max_states);
382  Cast(*fst_det_log, fst);
383  delete fst_log;
384  delete fst_det_log;
385 }
386 
387 inline
388 void DeterminizeInLog(VectorFst<StdArc> *fst) {
389  // DeterminizeInLog determinizes 'fst' in the log semiring.
390 
391  ArcSort(fst, ILabelCompare<StdArc>()); // helps DeterminizeStar to be faster.
392  VectorFst<LogArc> *fst_log = new VectorFst<LogArc>; // Want to determinize in log semiring.
393  Cast(*fst, fst_log);
394  VectorFst<StdArc> tmp;
395  *fst = tmp; // make fst empty to free up memory. [actually may make no difference..]
396  VectorFst<LogArc> *fst_det_log = new VectorFst<LogArc>;
397  Determinize(*fst_log, fst_det_log);
398  Cast(*fst_det_log, fst);
399  delete fst_log;
400  delete fst_det_log;
401 }
402 
403 
404 
405 // make it inline to avoid having to put it in a .cc file.
406 // destructive algorithm (changes ifst as well as ofst).
407 inline
408 void SafeDeterminizeMinimizeWrapperInLog(VectorFst<StdArc> *ifst, VectorFst<StdArc> *ofst, float delta) {
409  VectorFst<LogArc> *ifst_log = new VectorFst<LogArc>; // Want to determinize in log semiring.
410  Cast(*ifst, ifst_log);
411  VectorFst<LogArc> *ofst_log = new VectorFst<LogArc>;
412  SafeDeterminizeWrapper(ifst_log, ofst_log, delta);
413  Cast(*ofst_log, ofst);
414  delete ifst_log;
415  delete ofst_log;
416  RemoveEpsLocal(ofst); // this is "safe" and will never hurt. Do this in tropical, which is important.
417  MinimizeEncoded(ofst, delta); // Non-deterministic minimization will fail in log semiring so do it with StdARc.
418 }
419 
420 inline
421 void SafeDeterminizeWrapperInLog(VectorFst<StdArc> *ifst, VectorFst<StdArc> *ofst, float delta) {
422  VectorFst<LogArc> *ifst_log = new VectorFst<LogArc>; // Want to determinize in log semiring.
423  Cast(*ifst, ifst_log);
424  VectorFst<LogArc> *ofst_log = new VectorFst<LogArc>;
425  SafeDeterminizeWrapper(ifst_log, ofst_log, delta);
426  Cast(*ofst_log, ofst);
427  delete ifst_log;
428  delete ofst_log;
429 }
430 
431 
432 
433 template<class Arc>
434 void RemoveWeights(MutableFst<Arc> *ifst) {
435  typedef typename Arc::StateId StateId;
436  typedef typename Arc::Weight Weight;
437 
438  for (StateIterator<MutableFst<Arc> > siter(*ifst); !siter.Done(); siter.Next()) {
439  StateId s = siter.Value();
440  for (MutableArcIterator<MutableFst<Arc> > aiter(ifst, s); !aiter.Done(); aiter.Next()) {
441  Arc arc(aiter.Value());
442  arc.weight = Weight::One();
443  aiter.SetValue(arc);
444  }
445  if (ifst->Final(s) != Weight::Zero())
446  ifst->SetFinal(s, Weight::One());
447  }
448  ifst->SetProperties(kUnweighted, kUnweighted);
449 }
450 
// Functor that returns its argument unchanged.  Used in
// PrecedingInputSymbolsAreSame (non-functor version), and similar routines,
// as the trivial label-to-class mapping.
template<class T> struct IdentityFunction {
  typedef T Arg;
  typedef T Result;
  Result operator () (const Arg &value) const {
    return value;
  }
};
458 
459 template<class Arc>
460 bool PrecedingInputSymbolsAreSame(bool start_is_epsilon, const Fst<Arc> &fst) {
462  return PrecedingInputSymbolsAreSameClass(start_is_epsilon, fst, f);
463 }
464 
465 template<class Arc, class F> // F is functor type from labels to classes.
466 bool PrecedingInputSymbolsAreSameClass(bool start_is_epsilon, const Fst<Arc> &fst, const F &f) {
467  typedef typename F::Result ClassType;
468  typedef typename Arc::StateId StateId;
469  std::vector<ClassType> classes;
470  ClassType noClass = f(kNoLabel);
471 
472  if (start_is_epsilon) {
473  StateId start_state = fst.Start();
474  if (start_state < 0 || start_state == kNoStateId)
475  return true; // empty fst-- doesn't matter.
476  classes.resize(start_state+1, noClass);
477  classes[start_state] = 0;
478  }
479 
480  for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
481  StateId s = siter.Value();
482  for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
483  const Arc &arc = aiter.Value();
484  if (classes.size() <= arc.nextstate)
485  classes.resize(arc.nextstate+1, noClass);
486  if (classes[arc.nextstate] == noClass)
487  classes[arc.nextstate] = f(arc.ilabel);
488  else
489  if (classes[arc.nextstate] != f(arc.ilabel))
490  return false;
491  }
492  }
493  return true;
494 }
495 
496 template<class Arc>
497 bool FollowingInputSymbolsAreSame(bool end_is_epsilon, const Fst<Arc> &fst) {
499  return FollowingInputSymbolsAreSameClass(end_is_epsilon, fst, f);
500 }
501 
502 
503 template<class Arc, class F>
504 bool FollowingInputSymbolsAreSameClass(bool end_is_epsilon, const Fst<Arc> &fst, const F &f) {
505  typedef typename Arc::StateId StateId;
506  typedef typename Arc::Weight Weight;
507  typedef typename F::Result ClassType;
508  const ClassType noClass = f(kNoLabel), epsClass = f(0);
509  for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
510  StateId s = siter.Value();
511  ClassType c = noClass;
512  for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
513  const Arc &arc = aiter.Value();
514  if (c == noClass)
515  c = f(arc.ilabel);
516  else
517  if (c != f(arc.ilabel))
518  return false;
519  }
520  if (end_is_epsilon && c != noClass &&
521  c != epsClass && fst.Final(s) != Weight::Zero())
522  return false;
523  }
524  return true;
525 }
526 
527 template<class Arc>
528 void MakePrecedingInputSymbolsSame(bool start_is_epsilon, MutableFst<Arc> *fst) {
530  MakePrecedingInputSymbolsSameClass(start_is_epsilon, fst, f);
531 }
532 
533 template<class Arc, class F>
534 void MakePrecedingInputSymbolsSameClass(bool start_is_epsilon, MutableFst<Arc> *fst, const F &f) {
535  typedef typename F::Result ClassType;
536  typedef typename Arc::StateId StateId;
537  typedef typename Arc::Weight Weight;
538  std::vector<ClassType> classes;
539  ClassType noClass = f(kNoLabel);
540  ClassType epsClass = f(0);
541  if (start_is_epsilon) { // treat having-start-state as epsilon in-transition.
542  StateId start_state = fst->Start();
543  if (start_state < 0 || start_state == kNoStateId) // empty FST.
544  return;
545  classes.resize(start_state+1, noClass);
546  classes[start_state] = epsClass;
547  }
548 
549  // Find bad states (states with multiple input-symbols into them).
550  std::set<StateId> bad_states; // states that we need to change.
551  for (StateIterator<Fst<Arc> > siter(*fst); !siter.Done(); siter.Next()) {
552  StateId s = siter.Value();
553  for (ArcIterator<Fst<Arc> > aiter(*fst, s); !aiter.Done(); aiter.Next()) {
554  const Arc &arc = aiter.Value();
555  if (classes.size() <= static_cast<size_t>(arc.nextstate))
556  classes.resize(arc.nextstate+1, noClass);
557  if (classes[arc.nextstate] == noClass)
558  classes[arc.nextstate] = f(arc.ilabel);
559  else
560  if (classes[arc.nextstate] != f(arc.ilabel))
561  bad_states.insert(arc.nextstate);
562  }
563  }
564  if (bad_states.empty()) return; // Nothing to do.
565  kaldi::ConstIntegerSet<StateId> bad_states_ciset(bad_states); // faster lookup.
566 
567  // Work out list of arcs we have to change as (state, arc-offset).
568  // Can't do the actual changes in this pass, since we have to add new
569  // states which invalidates the iterators.
570  std::vector<std::pair<StateId, size_t> > arcs_to_change;
571  for (StateIterator<Fst<Arc> > siter(*fst); !siter.Done(); siter.Next()) {
572  StateId s = siter.Value();
573  for (ArcIterator<Fst<Arc> > aiter(*fst, s); !aiter.Done(); aiter.Next()) {
574  const Arc &arc = aiter.Value();
575  if (arc.ilabel != 0 &&
576  bad_states_ciset.count(arc.nextstate) != 0)
577  arcs_to_change.push_back(std::make_pair(s, aiter.Position()));
578  }
579  }
580  KALDI_ASSERT(!arcs_to_change.empty()); // since !bad_states.empty().
581 
582  std::map<std::pair<StateId, ClassType>, StateId> state_map;
583  // state_map is a map from (bad-state, input-symbol-class) to dummy-state.
584 
585  for (size_t i = 0; i < arcs_to_change.size(); i++) {
586  StateId s = arcs_to_change[i].first;
587  ArcIterator<MutableFst<Arc> > aiter(*fst, s);
588  aiter.Seek(arcs_to_change[i].second);
589  Arc arc = aiter.Value();
590 
591  // Transition is non-eps transition to "bad" state. Introduce new state (or find
592  // existing one).
593  std::pair<StateId, ClassType> p(arc.nextstate, f(arc.ilabel));
594  if (state_map.count(p) == 0) {
595  StateId newstate = state_map[p] = fst->AddState();
596  fst->AddArc(newstate, Arc(0, 0, Weight::One(), arc.nextstate));
597  }
598  StateId dst_state = state_map[p];
599  arc.nextstate = dst_state;
600 
601  // Initialize the MutableArcIterator only now, as the call to NewState()
602  // may have invalidated the first arc iterator.
603  MutableArcIterator<MutableFst<Arc> > maiter(fst, s);
604  maiter.Seek(arcs_to_change[i].second);
605  maiter.SetValue(arc);
606  }
607 }
608 
609 template<class Arc>
610 void MakeFollowingInputSymbolsSame(bool end_is_epsilon, MutableFst<Arc> *fst) {
612  MakeFollowingInputSymbolsSameClass(end_is_epsilon, fst, f);
613 }
614 
615 template<class Arc, class F>
616 void MakeFollowingInputSymbolsSameClass(bool end_is_epsilon, MutableFst<Arc> *fst, const F &f) {
617  typedef typename Arc::StateId StateId;
618  typedef typename Arc::Weight Weight;
619  typedef typename F::Result ClassType;
620  std::vector<StateId> bad_states;
621  ClassType noClass = f(kNoLabel);
622  ClassType epsClass = f(0);
623  for (StateIterator<Fst<Arc> > siter(*fst); !siter.Done(); siter.Next()) {
624  StateId s = siter.Value();
625  ClassType c = noClass;
626  bool bad = false;
627  for (ArcIterator<Fst<Arc> > aiter(*fst, s); !aiter.Done(); aiter.Next()) {
628  const Arc &arc = aiter.Value();
629  if (c == noClass)
630  c = f(arc.ilabel);
631  else
632  if (c != f(arc.ilabel)) {
633  bad = true;
634  break;
635  }
636  }
637  if (end_is_epsilon && c != noClass &&
638  c != epsClass && fst->Final(s) != Weight::Zero())
639  bad = true;
640  if (bad)
641  bad_states.push_back(s);
642  }
643  std::vector<Arc> my_arcs;
644  for (size_t i = 0; i < bad_states.size(); i++) {
645  StateId s = bad_states[i];
646  my_arcs.clear();
647  for (ArcIterator<MutableFst<Arc> > aiter(*fst, s); !aiter.Done(); aiter.Next())
648  my_arcs.push_back(aiter.Value());
649 
650  for (size_t j = 0; j < my_arcs.size(); j++) {
651  Arc &arc = my_arcs[j];
652  if (arc.ilabel != 0) {
653  StateId newstate = fst->AddState();
654  // Create a new state for each non-eps arc in original FST, out of each bad state.
655  // Not as optimal as it could be, but does avoid some complicated weight-pushing
656  // issues in which, to maintain stochasticity, we would have to know which semiring
657  // we want to maintain stochasticity in.
658  fst->AddArc(newstate, Arc(arc.ilabel, 0, Weight::One(), arc.nextstate));
659  MutableArcIterator<MutableFst<Arc> > maiter(fst, s);
660  maiter.Seek(j);
661  maiter.SetValue(Arc(0, arc.olabel, arc.weight, newstate));
662  }
663  }
664  }
665 }
666 
667 
668 template<class Arc>
669 VectorFst<Arc>* MakeLoopFst(const std::vector<const ExpandedFst<Arc> *> &fsts) {
670  typedef typename Arc::Weight Weight;
671  typedef typename Arc::StateId StateId;
672  typedef typename Arc::Label Label;
673 
674  VectorFst<Arc> *ans = new VectorFst<Arc>;
675  StateId loop_state = ans->AddState(); // = 0.
676  ans->SetStart(loop_state);
677  ans->SetFinal(loop_state, Weight::One());
678 
679  // "cache" is used as an optimization when some of the pointers in "fsts"
680  // may have the same value.
681  unordered_map<const ExpandedFst<Arc> *, Arc> cache;
682 
683  for (Label i = 0; i < static_cast<Label>(fsts.size()); i++) {
684  const ExpandedFst<Arc> *fst = fsts[i];
685  if (fst == NULL) continue;
686  { // optimization with cache: helpful if some members of "fsts" may
687  // contain the same pointer value (e.g. in GetHTransducer).
688  typename unordered_map<const ExpandedFst<Arc> *, Arc>::iterator
689  iter = cache.find(fst);
690  if (iter != cache.end()) {
691  Arc arc = iter->second;
692  arc.olabel = i;
693  ans->AddArc(0, arc);
694  continue;
695  }
696  }
697 
698  KALDI_ASSERT(fst->Properties(kAcceptor, true) == kAcceptor); // expect acceptor.
699 
700  StateId fst_num_states = fst->NumStates();
701  StateId fst_start_state = fst->Start();
702 
703  if (fst_start_state == kNoStateId)
704  continue; // empty fst.
705 
706  bool share_start_state =
707  fst->Properties(kInitialAcyclic, true) == kInitialAcyclic
708  && fst->NumArcs(fst_start_state) == 1
709  && fst->Final(fst_start_state) == Weight::Zero();
710 
711  std::vector<StateId> state_map(fst_num_states); // fst state -> ans state
712  for (StateId s = 0; s < fst_num_states; s++) {
713  if (s == fst_start_state && share_start_state) state_map[s] = loop_state;
714  else state_map[s] = ans->AddState();
715  }
716  if (!share_start_state) {
717  Arc arc(0, i, Weight::One(), state_map[fst_start_state]);
718  cache[fst] = arc;
719  ans->AddArc(0, arc);
720  }
721  for (StateId s = 0; s < fst_num_states; s++) {
722  // Add arcs out of state s.
723  for (ArcIterator<ExpandedFst<Arc> > aiter(*fst, s); !aiter.Done(); aiter.Next()) {
724  const Arc &arc = aiter.Value();
725  Label olabel = (s == fst_start_state && share_start_state ? i : 0);
726  Arc newarc(arc.ilabel, olabel, arc.weight, state_map[arc.nextstate]);
727  ans->AddArc(state_map[s], newarc);
728  if (s == fst_start_state && share_start_state)
729  cache[fst] = newarc;
730  }
731  if (fst->Final(s) != Weight::Zero()) {
732  KALDI_ASSERT(!(s == fst_start_state && share_start_state));
733  ans->AddArc(state_map[s], Arc(0, 0, fst->Final(s), loop_state));
734  }
735  }
736  }
737  return ans;
738 }
739 
740 
741 template<class Arc>
742 void ClearSymbols(bool clear_input,
743  bool clear_output,
744  MutableFst<Arc> *fst) {
745  for (StateIterator<MutableFst<Arc> > siter(*fst);
746  !siter.Done();
747  siter.Next()) {
748  typename Arc::StateId s = siter.Value();
749  for (MutableArcIterator<MutableFst<Arc> > aiter(fst, s);
750  !aiter.Done();
751  aiter.Next()) {
752  Arc arc = aiter.Value();
753  bool change = false;
754  if (clear_input && arc.ilabel != 0) {
755  arc.ilabel = 0;
756  change = true;
757  }
758  if (clear_output && arc.olabel != 0) {
759  arc.olabel = 0;
760  change = true;
761  }
762  if (change) {
763  aiter.SetValue(arc);
764  }
765  }
766  }
767 }
768 
769 
770 template<class Arc>
771 void ApplyProbabilityScale(float scale, MutableFst<Arc> *fst) {
772  typedef typename Arc::Weight Weight;
773  typedef typename Arc::StateId StateId;
774  for (StateIterator<MutableFst<Arc> > siter(*fst);
775  !siter.Done();
776  siter.Next()) {
777  StateId s = siter.Value();
778  for (MutableArcIterator<MutableFst<Arc> > aiter(fst, s);
779  !aiter.Done();
780  aiter.Next()) {
781  Arc arc = aiter.Value();
782  arc.weight = Weight(arc.weight.Value() * scale);
783  aiter.SetValue(arc);
784  }
785  if (fst->Final(s) != Weight::Zero())
786  fst->SetFinal(s, Weight(fst->Final(s).Value() * scale));
787  }
788 }
789 
790 
791 // return arc-offset of self-loop with ilabel (or -1 if none exists).
792 // if more than one such self-loop, pick first one.
793 template<class Arc>
794 ssize_t FindSelfLoopWithILabel(const Fst<Arc> &fst, typename Arc::StateId s) {
795  for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next())
796  if (aiter.Value().nextstate == s
797  && aiter.Value().ilabel != 0) return static_cast<ssize_t>(aiter.Position());
798  return static_cast<ssize_t>(-1);
799 }
800 
801 
802 template<class Arc>
803 bool EqualAlign(const Fst<Arc> &ifst,
804  typename Arc::StateId length,
805  int rand_seed,
806  MutableFst<Arc> *ofst,
807  int num_retries) {
808  srand(rand_seed);
809  KALDI_ASSERT(ofst->NumStates() == 0); // make sure ofst empty.
810  // make sure all states can reach final-state (or this algorithm may enter
811  // infinite loop.
812  KALDI_ASSERT(ifst.Properties(kCoAccessible, true) == kCoAccessible);
813 
814  typedef typename Arc::StateId StateId;
815  typedef typename Arc::Weight Weight;
816 
817  if (ifst.Start() == kNoStateId) {
818  KALDI_WARN << "Empty input fst.";
819  return false;
820  }
821  // First select path through ifst.
822  std::vector<StateId> path;
823  std::vector<size_t> arc_offsets; // arc taken out of each state.
824  std::vector<int> nof_ilabels;
825 
826  StateId num_ilabels = 0;
827  int retry_no = 0;
828 
829  // Under normal circumstances, this will be one-pass-only process
830  // Multiple tries might be needed in special cases, typically when
831  // the number of frames is close to number of transitions from
832  // the start node to the final node. It usually happens for really
833  // short utterances
834  do {
835  num_ilabels = 0;
836  arc_offsets.clear();
837  path.clear();
838  path.push_back(ifst.Start());
839 
840  while (1) {
841  // Select either an arc or final-prob.
842  StateId s = path.back();
843  size_t num_arcs = ifst.NumArcs(s);
844  size_t num_arcs_tot = num_arcs;
845  if (ifst.Final(s) != Weight::Zero()) num_arcs_tot++;
846  // kaldi::RandInt is a bit like Rand(), but gets around situations
847  // where RAND_MAX is very small.
848  // Change this to Rand() % num_arcs_tot if compile issues arise
849  size_t arc_offset = static_cast<size_t>(kaldi::RandInt(0, num_arcs_tot-1));
850 
851  if (arc_offset < num_arcs) { // an actual arc.
852  ArcIterator<Fst<Arc> > aiter(ifst, s);
853  aiter.Seek(arc_offset);
854  const Arc &arc = aiter.Value();
855  if (arc.nextstate == s) {
856  continue; // don't take this self-loop arc
857  } else {
858  arc_offsets.push_back(arc_offset);
859  path.push_back(arc.nextstate);
860  if (arc.ilabel != 0) num_ilabels++;
861  }
862  } else {
863  break; // Chose final-prob.
864  }
865  }
866 
867  nof_ilabels.push_back(num_ilabels);
868  } while (( ++retry_no < num_retries) && (num_ilabels > length));
869 
870  if (num_ilabels > length) {
871  std::stringstream ilabel_vec;
872  std::copy(nof_ilabels.begin(), nof_ilabels.end(),
873  std::ostream_iterator<int>(ilabel_vec, ","));
874  std::string s = ilabel_vec.str();
875  s.erase(s.end() - 1);
876  KALDI_WARN << "EqualAlign: the randomly constructed paths lengths: " << s;
877  KALDI_WARN << "EqualAlign: utterance has too few frames " << length
878  << " to align.";
879  return false; // can't make it shorter by adding self-loops!.
880  }
881 
882  StateId num_self_loops = 0;
883  std::vector<ssize_t> self_loop_offsets(path.size());
884  for (size_t i = 0; i < path.size(); i++)
885  if ( (self_loop_offsets[i] = FindSelfLoopWithILabel(ifst, path[i]))
886  != static_cast<ssize_t>(-1) )
887  num_self_loops++;
888 
889  if (num_self_loops == 0
890  && num_ilabels < length) {
891  KALDI_WARN << "No self-loops on chosen path; cannot match length.";
892  return false; // no self-loops to make it longer.
893  }
894 
895  StateId num_extra = length - num_ilabels; // Number of self-loops we need.
896 
897  StateId min_num_loops = 0;
898  if (num_extra != 0) min_num_loops = num_extra / num_self_loops; // prevent div by zero.
899  StateId num_with_one_more_loop = num_extra - (min_num_loops*num_self_loops);
900  KALDI_ASSERT(num_with_one_more_loop < num_self_loops || num_self_loops == 0);
901 
902  ofst->AddState();
903  ofst->SetStart(0);
904  StateId cur_state = 0;
905  StateId counter = 0; // tell us when we should stop adding one more loop.
906  for (size_t i = 0; i < path.size(); i++) {
907  // First, add any self-loops that are necessary.
908  StateId num_loops = 0;
909  if (self_loop_offsets[i] != static_cast<ssize_t>(-1)) {
910  num_loops = min_num_loops + (counter < num_with_one_more_loop ? 1 : 0);
911  counter++;
912  }
913  for (StateId j = 0; j < num_loops; j++) {
914  ArcIterator<Fst<Arc> > aiter(ifst, path[i]);
915  aiter.Seek(self_loop_offsets[i]);
916  Arc arc = aiter.Value();
917  KALDI_ASSERT(arc.nextstate == path[i]
918  && arc.ilabel != 0); // make sure self-loop with ilabel.
919  StateId next_state = ofst->AddState();
920  ofst->AddArc(cur_state, Arc(arc.ilabel, arc.olabel, arc.weight, next_state));
921  cur_state = next_state;
922  }
923  if (i+1 < path.size()) { // add forward transition.
924  ArcIterator<Fst<Arc> > aiter(ifst, path[i]);
925  aiter.Seek(arc_offsets[i]);
926  Arc arc = aiter.Value();
927  KALDI_ASSERT(arc.nextstate == path[i+1]);
928  StateId next_state = ofst->AddState();
929  ofst->AddArc(cur_state, Arc(arc.ilabel, arc.olabel, arc.weight, next_state));
930  cur_state = next_state;
931  } else { // add final-prob.
932  Weight weight = ifst.Final(path[i]);
933  KALDI_ASSERT(weight != Weight::Zero());
934  ofst->SetFinal(cur_state, weight);
935  }
936  }
937  return true;
938 }
939 
940 
941 // This function identifies two types of useless arcs:
942 // those where arc A and arc B both go from state X to
943 // state Y with the same input symbol (remove the one
944 // with smaller probability, or an arbitrary one if they
945 // are the same); and those where A is an arc from state X
946 // to state X, with epsilon input symbol [remove A].
947 // Only works for tropical (not log) semiring as it uses
948 // NaturalLess.
949 template<class Arc>
950 void RemoveUselessArcs(MutableFst<Arc> *fst) {
951  typedef typename Arc::Label Label;
952  typedef typename Arc::StateId StateId;
953  typedef typename Arc::Weight Weight;
954  NaturalLess<Weight> nl;
955  StateId non_coacc_state = kNoStateId;
956  size_t num_arcs_removed = 0, tot_arcs = 0;
957  for (StateIterator<MutableFst<Arc> > siter(*fst);
958  !siter.Done();
959  siter.Next()) {
960  std::vector<size_t> arcs_to_delete;
961  std::vector<Arc> arcs;
962  // pair2arclist lets us look up the arcs
963  std::map<std::pair<Label, StateId>, std::vector<size_t> > pair2arclist;
964  StateId state = siter.Value();
965  for (ArcIterator<MutableFst<Arc> > aiter(*fst, state);
966  !aiter.Done();
967  aiter.Next()) {
968  size_t pos = arcs.size();
969  const Arc &arc = aiter.Value();
970  arcs.push_back(arc);
971  pair2arclist[std::make_pair(arc.ilabel, arc.nextstate)].push_back(pos);
972  }
973  typename std::map<std::pair<Label, StateId>, std::vector<size_t> >::iterator
974  iter = pair2arclist.begin(), end = pair2arclist.end();
975  for (; iter!= end; ++iter) {
976  const std::vector<size_t> &poslist = iter->second;
977  if (poslist.size() > 1) { // >1 arc with same ilabel, dest-state
978  size_t best_pos = poslist[0];
979  Weight best_weight = arcs[best_pos].weight;
980  for (size_t j = 1; j < poslist.size(); j++) {
981  size_t pos = poslist[j];
982  Weight this_weight = arcs[pos].weight;
983  if (nl(this_weight, best_weight)) { // NaturalLess seems to be somehow
984  // "backwards".
985  best_weight = this_weight; // found a better one.
986  best_pos = pos;
987  }
988  }
989  for (size_t j = 0; j < poslist.size(); j++)
990  if (poslist[j] != best_pos)
991  arcs_to_delete.push_back(poslist[j]);
992  } else {
993  KALDI_ASSERT(poslist.size() == 1);
994  size_t pos = poslist[0];
995  Arc &arc = arcs[pos];
996  if (arc.ilabel == 0 && arc.nextstate == state)
997  arcs_to_delete.push_back(pos);
998  }
999  }
1000  tot_arcs += arcs.size();
1001  if (arcs_to_delete.size() != 0) {
1002  num_arcs_removed += arcs_to_delete.size();
1003  if (non_coacc_state == kNoStateId)
1004  non_coacc_state = fst->AddState();
1005  MutableArcIterator<MutableFst<Arc> > maiter(fst, state);
1006  for (size_t j = 0; j < arcs_to_delete.size(); j++) {
1007  size_t pos = arcs_to_delete[j];
1008  maiter.Seek(pos);
1009  arcs[pos].nextstate = non_coacc_state;
1010  maiter.SetValue(arcs[pos]);
1011  }
1012  }
1013  }
1014  if (non_coacc_state != kNoStateId)
1015  Connect(fst);
1016  KALDI_VLOG(1) << "removed " << num_arcs_removed << " of " << tot_arcs
1017  << "arcs.";
1018 }
1019 
1020 template<class Arc>
1021 void PhiCompose(const Fst<Arc> &fst1,
1022  const Fst<Arc> &fst2,
1023  typename Arc::Label phi_label,
1024  MutableFst<Arc> *ofst) {
1025  KALDI_ASSERT(phi_label != kNoLabel); // just use regular compose in this case.
1026  typedef Fst<Arc> F;
1027  typedef PhiMatcher<SortedMatcher<F> > PM;
1028  CacheOptions base_opts;
1029  base_opts.gc_limit = 0; // Cache only the last state for fastest copy.
1030  // ComposeFstImplOptions templated on matcher for fst1, matcher for fst2.
1031  // The matcher for fst1 doesn't matter; we'll use fst2's matcher.
1032  ComposeFstImplOptions<SortedMatcher<F>, PM> impl_opts(base_opts);
1033 
1034  // the false below is something called phi_loop which is something I don't
1035  // fully understand, but I don't think we want it.
1036 
1037  // These pointers are taken ownership of, by ComposeFst.
1038  PM *phi_matcher =
1039  new PM(fst2, MATCH_INPUT, phi_label, false);
1040  SortedMatcher<F> *sorted_matcher =
1041  new SortedMatcher<F>(fst1, MATCH_NONE); // tell it
1042  // not to use this matcher, as this would mean we would
1043  // not follow phi transitions.
1044  impl_opts.matcher1 = sorted_matcher;
1045  impl_opts.matcher2 = phi_matcher;
1046  *ofst = ComposeFst<Arc>(fst1, fst2, impl_opts);
1047  Connect(ofst);
1048 }
1049 
1050 template<class Arc>
1052  typename Arc::Label phi_label,
1053  typename Arc::StateId s,
1054  MutableFst<Arc> *fst) {
1055  typedef typename Arc::Weight Weight;
1056  if (fst->Final(s) == Weight::Zero()) {
1057  // search for phi transition. We assume there
1058  // is just one-- phi nondeterminism is not allowed
1059  // anyway.
1060  int num_phis = 0;
1061  for (ArcIterator<Fst<Arc> > aiter(*fst, s);
1062  !aiter.Done(); aiter.Next()) {
1063  const Arc &arc = aiter.Value();
1064  if (arc.ilabel == phi_label) {
1065  num_phis++;
1066  if (arc.nextstate == s) continue; // don't expect
1067  // phi loops but ignore them anyway.
1068 
1069  // If this recurses infinitely, it means there
1070  // are loops of phi transitions, which there should
1071  // not be in a normal backoff LM. We could make this
1072  // routine work for this case, but currently there is
1073  // no need.
1074  PropagateFinalInternal(phi_label, arc.nextstate, fst);
1075  if (fst->Final(arc.nextstate) != Weight::Zero())
1076  fst->SetFinal(s, Times(fst->Final(arc.nextstate), arc.weight));
1077  }
1078  KALDI_ASSERT(num_phis <= 1 && "Phi nondeterminism found");
1079  }
1080  }
1081 }
1082 
1083 template<class Arc>
1084 void PropagateFinal(typename Arc::Label phi_label,
1085  MutableFst<Arc> *fst) {
1086  typedef typename Arc::StateId StateId;
1087  if (fst->Properties(kIEpsilons, true)) // just warn.
1088  KALDI_WARN << "PropagateFinal: this may not work as desired "
1089  "since your FST has input epsilons.";
1090  StateId num_states = fst->NumStates();
1091  for (StateId s = 0; s < num_states; s++)
1092  PropagateFinalInternal(phi_label, s, fst);
1093 }
1094 
1095 template<class Arc>
1096 void RhoCompose(const Fst<Arc> &fst1,
1097  const Fst<Arc> &fst2,
1098  typename Arc::Label rho_label,
1099  MutableFst<Arc> *ofst) {
1100  KALDI_ASSERT(rho_label != kNoLabel); // just use regular compose in this case.
1101  typedef Fst<Arc> F;
1102  typedef RhoMatcher<SortedMatcher<F> > RM;
1103  CacheOptions base_opts;
1104  base_opts.gc_limit = 0; // Cache only the last state for fastest copy.
1105  // ComposeFstImplOptions templated on matcher for fst1, matcher for fst2.
1106  // The matcher for fst1 doesn't matter; we'll use fst2's matcher.
1107  ComposeFstImplOptions<SortedMatcher<F>, RM> impl_opts(base_opts);
1108 
1109  // the false below is something called rho_loop which is something I don't
1110  // fully understand, but I don't think we want it.
1111 
1112  // These pointers are taken ownership of, by ComposeFst.
1113  RM *rho_matcher =
1114  new RM(fst2, MATCH_INPUT, rho_label);
1115  SortedMatcher<F> *sorted_matcher =
1116  new SortedMatcher<F>(fst1, MATCH_NONE); // tell it
1117  // not to use this matcher, as this would mean we would
1118  // not follow rho transitions.
1119  impl_opts.matcher1 = sorted_matcher;
1120  impl_opts.matcher2 = rho_matcher;
1121  *ofst = ComposeFst<Arc>(fst1, fst2, impl_opts);
1122  Connect(ofst);
1123 }
1124 
1125 
1126 // Declare an override of the template below.
1127 template<>
1128 inline bool IsStochasticFst(const Fst<LogArc> &fst,
1129  float delta,
1130  LogArc::Weight *min_sum,
1131  LogArc::Weight *max_sum);
1132 
1133 // Will override this for LogArc where NaturalLess will not work.
1134 template<class Arc>
1135 inline bool IsStochasticFst(const Fst<Arc> &fst,
1136  float delta,
1137  typename Arc::Weight *min_sum,
1138  typename Arc::Weight *max_sum) {
1139  typedef typename Arc::StateId StateId;
1140  typedef typename Arc::Weight Weight;
1141  NaturalLess<Weight> nl;
1142  bool first_time = true;
1143  bool ans = true;
1144  if (min_sum) *min_sum = Arc::Weight::One();
1145  if (max_sum) *max_sum = Arc::Weight::One();
1146  for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
1147  StateId s = siter.Value();
1148  Weight sum = fst.Final(s);
1149  for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
1150  const Arc &arc = aiter.Value();
1151  sum = Plus(sum, arc.weight);
1152  }
1153  if (!ApproxEqual(Weight::One(), sum, delta)) ans = false;
1154  if (first_time) {
1155  first_time = false;
1156  if (max_sum) *max_sum = sum;
1157  if (min_sum) *min_sum = sum;
1158  } else {
1159  if (max_sum && nl(*max_sum, sum)) *max_sum = sum;
1160  if (min_sum && nl(sum, *min_sum)) *min_sum = sum;
1161  }
1162  }
1163  if (first_time) { // just avoid NaNs if FST was empty.
1164  if (max_sum) *max_sum = Weight::One();
1165  if (min_sum) *min_sum = Weight::One();
1166  }
1167  return ans;
1168 }
1169 
1170 
1171 // Overriding template for LogArc as NaturalLess does not work there.
1172 template<>
1173 inline bool IsStochasticFst(const Fst<LogArc> &fst,
1174  float delta,
1175  LogArc::Weight *min_sum,
1176  LogArc::Weight *max_sum) {
1177  typedef LogArc Arc;
1178  typedef Arc::StateId StateId;
1179  typedef Arc::Weight Weight;
1180  bool first_time = true;
1181  bool ans = true;
1182  if (min_sum) *min_sum = LogArc::Weight::One();
1183  if (max_sum) *max_sum = LogArc::Weight::One();
1184  for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
1185  StateId s = siter.Value();
1186  Weight sum = fst.Final(s);
1187  for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
1188  const Arc &arc = aiter.Value();
1189  sum = Plus(sum, arc.weight);
1190  }
1191  if (!ApproxEqual(Weight::One(), sum, delta)) ans = false;
1192  if (first_time) {
1193  first_time = false;
1194  if (max_sum) *max_sum = sum;
1195  if (min_sum) *min_sum = sum;
1196  } else {
1197  // note that max and min are reversed from their normal
1198  // meanings here (max and min w.r.t. the underlying probabilities).
1199  if (max_sum && sum.Value() < max_sum->Value()) *max_sum = sum;
1200  if (min_sum && sum.Value() > min_sum->Value()) *min_sum = sum;
1201  }
1202  }
1203  if (first_time) { // just avoid NaNs if FST was empty.
1204  if (max_sum) *max_sum = Weight::One();
1205  if (min_sum) *min_sum = Weight::One();
1206  }
1207  return ans;
1208 }
1209 
1210 // Tests whether a tropical FST is stochastic in the log
1211 // semiring. (casts it and does the check.)
1212 // This function deals with the generic fst.
1213 // This version currently supports ConstFst<StdArc> or VectorFst<StdArc>.
1214 // Otherwise, it will be died with an error.
1215 inline bool IsStochasticFstInLog(const Fst<StdArc> &fst,
1216  float delta,
1217  StdArc::Weight *min_sum,
1218  StdArc::Weight *max_sum) {
1219  bool ans = false;
1220  LogArc::Weight log_min = LogArc::Weight::One(),
1221  log_max = LogArc::Weight::Zero();
1222  if (fst.Type() == "const") {
1223  ConstFst<LogArc> logfst;
1224  Cast(dynamic_cast<const ConstFst<StdArc>&>(fst), &logfst);
1225  ans = IsStochasticFst(logfst, delta, &log_min, &log_max);
1226  } else if (fst.Type() == "vector") {
1227  VectorFst<LogArc> logfst;
1228  Cast(dynamic_cast<const VectorFst<StdArc>&>(fst), &logfst);
1229  ans = IsStochasticFst(logfst, delta, &log_min, &log_max);
1230  } else {
1231  KALDI_ERR << "This version currently supports ConstFst<StdArc> "
1232  << "or VectorFst<StdArc>";
1233  }
1234  if (min_sum) *min_sum = StdArc::Weight(log_min.Value());
1235  if (max_sum) *max_sum = StdArc::Weight(log_max.Value());
1236  return ans;
1237 }
1238 
1239 } // namespace fst.
1240 
1241 #endif
fst::StdArc::StateId StateId
void CopySetToVector(const std::set< T > &s, std::vector< T > *v)
Copies the elements of a set to a vector.
Definition: stl-utils.h:86
void GetSymbols(const SymbolTable &symtab, bool include_eps, std::vector< I > *syms_out)
const std::vector< I > * symbol_mapping_
#define KALDI_ASSERT_IS_INTEGER_TYPE(I)
Definition: kaldi-utils.h:133
ssize_t FindSelfLoopWithILabel(const Fst< Arc > &fst, typename Arc::StateId s)
void RemoveEpsLocal(MutableFst< Arc > *fst)
RemoveEpsLocal remove some (but not necessarily all) epsilons in an FST, using an algorithm that is g...
void MakePrecedingInputSymbolsSame(bool start_is_epsilon, MutableFst< Arc > *fst)
MakePrecedingInputSymbolsSame ensures that all arcs entering any given fst state have the same input ...
void PreDeterminize(MutableFst< Arc > *fst, typename Arc::Label first_new_sym, std::vector< Int > *symsOut)
void PropagateFinal(typename Arc::Label phi_label, MutableFst< Arc > *fst)
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
LatticeWeightTpl< FloatType > Plus(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
void PhiCompose(const Fst< Arc > &fst1, const Fst< Arc > &fst2, typename Arc::Label phi_label, MutableFst< Arc > *ofst)
void DeterminizeInLog(VectorFst< StdArc > *fst)
void MinimizeEncoded(VectorFst< Arc > *fst, float delta=kDelta)
Definition: fstext-utils.h:114
MapSymbolsAction InputSymbolsAction() const
bool ApproxEqual(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2, float delta=kDelta)
void ClearSymbols(bool clear_input, bool clear_output, MutableFst< Arc > *fst)
ClearSymbols sets all the symbols on the input and/or output side of the FST to zero, as specified.
void SafeDeterminizeWrapperInLog(VectorFst< StdArc > *ifst, VectorFst< StdArc > *ofst, float delta)
bool GetLinearSymbolSequence(const Fst< Arc > &fst, std::vector< I > *isymbols_out, std::vector< I > *osymbols_out, typename Arc::Weight *tot_weight_out)
GetLinearSymbolSequence gets the symbol sequence from a linear FST.
void GetInputSymbols(const Fst< Arc > &fst, bool include_eps, std::vector< I > *symbols)
GetInputSymbols gets the list of symbols on the input of fst (including epsilon, if include_eps == tr...
void RemoveUselessArcs(MutableFst< Arc > *fst)
MapInputSymbolsMapper(const std::vector< I > &to_remove, bool copy)
void MakeLinearAcceptorWithAlternatives(const std::vector< std::vector< I > > &labels, MutableFst< Arc > *ofst)
Creates an unweighted acceptor with a linear structure, with alternatives at each position...
MapFinalAction FinalAction() const
void DeterminizeStarInLog(VectorFst< StdArc > *fst, float delta, bool *debug_ptr, int max_states)
void MakeLinearAcceptor(const std::vector< I > &labels, MutableFst< Arc > *ofst)
Creates unweighted linear acceptor from symbol sequence.
LatticeWeightTpl< FloatType > Times(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
bool EqualAlign(const Fst< Arc > &ifst, typename Arc::StateId length, int rand_seed, MutableFst< Arc > *ofst, int num_retries)
EqualAlign is similar to RandGen, but it generates a sequence with exactly "length" input symbols...
uint64 Properties(uint64 props) const
void ConvertNbestToVector(const Fst< Arc > &fst, std::vector< VectorFst< Arc > > *fsts_out)
This function converts an FST with a special structure, which is output by the OpenFst functions Shor...
void NbestAsFsts(const Fst< Arc > &fst, size_t n, std::vector< VectorFst< Arc > > *fsts_out)
Takes the n-shortest-paths (using ShortestPath), but outputs the result as a vector of up to n fsts...
struct rnnlm::@11::@12 n
#define KALDI_ERR
Definition: kaldi-error.h:147
Arc::Label HighestNumberedOutputSymbol(const Fst< Arc > &fst)
Returns the highest numbered output symbol id of the FST (or zero for an empty FST).
bool IsStochasticFstInLog(const Fst< StdArc > &fst, float delta, StdArc::Weight *min_sum, StdArc::Weight *max_sum)
#define KALDI_WARN
Definition: kaldi-error.h:150
MapSymbolsAction OutputSymbolsAction() const
VectorFst< Arc > * MakeLoopFst(const std::vector< const ExpandedFst< Arc > *> &fsts)
MakeLoopFst creates an FST that has a state that is both initial and final (weight == Weight::One())...
bool PrecedingInputSymbolsAreSame(bool start_is_epsilon, const Fst< Arc > &fst)
Returns true if and only if the FST is such that the input symbols on arcs entering any given state a...
bool IsStochasticFst(const Fst< LogArc > &fst, float delta, LogArc::Weight *min_sum, LogArc::Weight *max_sum)
fst::StdArc::Label Label
void ApplyProbabilityScale(float scale, MutableFst< Arc > *fst)
ApplyProbabilityScale is applicable to FSTs in the log or tropical semiring.
bool PrecedingInputSymbolsAreSameClass(bool start_is_epsilon, const Fst< Arc > &fst, const F &f)
This is as PrecedingInputSymbolsAreSame, but with a functor f that maps labels to classes...
void MakePrecedingInputSymbolsSameClass(bool start_is_epsilon, MutableFst< Arc > *fst, const F &f)
As MakePrecedingInputSymbolsSame, but takes a functor object that maps labels to classes.
fst::StdArc::Weight Weight
Arc operator()(const Arc &arc_in)
bool FollowingInputSymbolsAreSame(bool end_is_epsilon, const Fst< Arc > &fst)
Returns true if and only if the FST is such that the input symbols on arcs exiting any given state al...
Arc::Label HighestNumberedInputSymbol(const Fst< Arc > &fst)
Returns the highest numbered input symbol id of the FST (or zero for an empty FST).
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void MapInputSymbols(const std::vector< I > &symbol_mapping, MutableFst< Arc > *fst)
void RemoveWeights(MutableFst< Arc > *ifst)
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
void SafeDeterminizeWrapper(MutableFst< Arc > *ifst, MutableFst< Arc > *ofst, float delta)
Does PreDeterminize and DeterminizeStar and then removes the disambiguation symbols.
void PropagateFinalInternal(typename Arc::Label phi_label, typename Arc::StateId s, MutableFst< Arc > *fst)
void SafeDeterminizeMinimizeWrapperInLog(VectorFst< StdArc > *ifst, VectorFst< StdArc > *ofst, float delta)
SafeDeterminizeMinimizeWrapperInLog is as SafeDeterminizeMinimizeWrapper except it first casts to the l...
void SafeDeterminizeMinimizeWrapper(MutableFst< Arc > *ifst, VectorFst< Arc > *ofst, float delta)
SafeDeterminizeMinimizeWapper is as SafeDeterminizeWrapper except that it also minimizes (encoded min...
void RhoCompose(const Fst< Arc > &fst1, const Fst< Arc > &fst2, typename Arc::Label rho_label, MutableFst< Arc > *ofst)
bool FollowingInputSymbolsAreSameClass(bool end_is_epsilon, const Fst< Arc > &fst, const F &f)
Arc::StateId NumArcs(const ExpandedFst< Arc > &fst)
Returns the total number of arcs in an FST.
void MakeFollowingInputSymbolsSameClass(bool end_is_epsilon, MutableFst< Arc > *fst, const F &f)
As MakeFollowingInputSymbolsSame, but takes a functor object that maps labels to classes.
void MakeFollowingInputSymbolsSame(bool end_is_epsilon, MutableFst< Arc > *fst)
MakeFollowingInputSymbolsSame ensures that all arcs exiting any given fst state have the same input s...
void GetOutputSymbols(const Fst< Arc > &fst, bool include_eps, std::vector< I > *symbols)
GetOutputSymbols gets the list of symbols on the output of fst (including epsilon, if include_eps == true)
bool DeterminizeStar(F &ifst, MutableFst< typename F::Arc > *ofst, float delta, bool *debug_ptr, int max_states, bool allow_partial)
This function implements the normal version of DeterminizeStar, in which the output strings are repre...
int32 RandInt(int32 min_val, int32 max_val, struct RandomState *state)
Definition: kaldi-math.cc:95
void RemoveSomeInputSymbols(const std::vector< I > &to_remove, MutableFst< Arc > *fst)
RemoveSomeInputSymbols removes any symbol that appears in "to_remove", from the input side of the FST...