remove-eps-local-test.cc
Go to the documentation of this file.
1 // fstext/remove-eps-local-test.cc
2 
3 // Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
#include <cassert>
#include <iostream>
#include <sstream>
#include <vector>

#include "fstext/fstext-utils.h"
#include "fstext/fst-test-utils.h"
#include "fstext/remove-eps-local.h"
#include "base/kaldi-math.h"
25 
26 
27 namespace fst
28 {
29 using std::vector;
30 using std::cout;
31 
32 // Don't instantiate with log semiring, as RandEquivalent may fail.
33 template<class Arc> static void TestRemoveEpsLocal() {
34  typedef typename Arc::Label Label;
35  typedef typename Arc::StateId StateId;
36  typedef typename Arc::Weight Weight;
37 
38  VectorFst<Arc> fst;
39  int n_syms = 2 + kaldi::Rand() % 5, n_arcs = 5 + kaldi::Rand() % 30, n_final = 1 + kaldi::Rand()%10;
40 
41  SymbolTable symtab("my-symbol-table"), *sptr = &symtab;
42 
43  vector<Label> all_syms; // including epsilon.
44  // Put symbols in the symbol table from 1..n_syms-1.
45  for (size_t i = 0;i < (size_t)n_syms;i++) {
46  std::stringstream ss;
47  if (i == 0) ss << "<eps>";
48  else ss<<i;
49  Label cur_lab = sptr->AddSymbol(ss.str());
50  assert(cur_lab == (Label)i);
51  all_syms.push_back(cur_lab);
52  }
53  assert(all_syms[0] == 0);
54 
55  fst.AddState();
56  int cur_num_states = 1;
57  for (int i = 0; i < n_arcs; i++) {
58  StateId src_state = kaldi::Rand() % cur_num_states;
59  StateId dst_state;
60  if (kaldi::RandUniform() < 0.1) dst_state = kaldi::Rand() % cur_num_states;
61  else {
62  dst_state = cur_num_states++; fst.AddState();
63  }
64  Arc arc;
65  if (kaldi::RandUniform() < 0.3) arc.ilabel = all_syms[kaldi::Rand()%all_syms.size()];
66  else arc.ilabel = 0;
67  if (kaldi::RandUniform() < 0.3) arc.olabel = all_syms[kaldi::Rand()%all_syms.size()];
68  else arc.olabel = 0;
69  arc.weight = (Weight) (0 + 0.1*(kaldi::Rand() % 5));
70  arc.nextstate = dst_state;
71  fst.AddArc(src_state, arc);
72  }
73  for (int i = 0; i < n_final; i++) {
74  fst.SetFinal(kaldi::Rand() % cur_num_states, (Weight) (0 + 0.1*(kaldi::Rand() % 5)));
75  }
76 
77  if (kaldi::RandUniform() < 0.8) fst.SetStart(0); // usually leads to nicer examples.
78  else fst.SetStart(kaldi::Rand() % cur_num_states);
79 
80  Connect(&fst);
81  if (fst.Start() == kNoStateId) return; // "Connect" made it empty.
82 
83  std::cout <<" printing after trimming\n";
84  {
85  FstPrinter<Arc> fstprinter(fst, sptr, sptr, NULL, false, true, "\t");
86  fstprinter.Print(&std::cout, "standard output");
87  }
88 
89  VectorFst<Arc> fst_copy1(fst);
90 
91 
92  RemoveEpsLocal(&fst_copy1);
93 
94 
95 
96  {
97  std::cout << "copy1 = \n";
98  FstPrinter<Arc> fstprinter(fst_copy1, sptr, sptr, NULL, false, true, "\t");
99  fstprinter.Print(&std::cout, "standard output");
100  }
101 
102 
103  int num_states_0 = fst.NumStates();
104  int num_states_1 = fst_copy1.NumStates();
105 
106 
107  std::cout << "Number of states 0 = "<<num_states_0<<", 1 = "<<num_states_1<<'\n';
108 
109  assert(RandEquivalent(fst, fst_copy1, 5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 100/*path length-- max?*/));
110 }
111 
112 
114  // test that RemoveEpsLocalSpecial preserves equivalence in tropical while
115  // maintaining stochasticity in log.
116  typedef VectorFst<LogArc> Fst;
117  typedef LogArc::Weight Weight;
118  typedef LogArc::StateId StateId;
119  typedef LogArc Arc;
120  VectorFst<LogArc> *logfst = RandFst<LogArc>();
121 
122  { // Make the FST stochastic.
123  for (StateId s = 0; s < logfst->NumStates(); s++) {
124  Weight w = logfst->Final(s);
125  for (ArcIterator<Fst> aiter(*logfst, s); !aiter.Done(); aiter.Next()) {
126  w = Plus(w, aiter.Value().weight);
127  }
128  if (w != Weight::Zero()) {
129  logfst->SetFinal(s, Divide(logfst->Final(s), w, DIVIDE_ANY));
130  for (MutableArcIterator<Fst> aiter(logfst, s); !aiter.Done(); aiter.Next()) {
131  Arc a = aiter.Value();
132  a.weight = Divide(a.weight, w, DIVIDE_ANY);
133  aiter.SetValue(a);
134  }
135  }
136  }
137  }
138 #ifndef _MSC_VER
139  assert(IsStochasticFst(*logfst, kDelta*10));
140 #endif
141  {
142  std::cout << "logfst = \n";
143  FstPrinter<LogArc> fstprinter(*logfst, NULL, NULL, NULL, false, true, "\t");
144  fstprinter.Print(&std::cout, "standard output");
145  }
146 
147  VectorFst<StdArc> fst;
148  Cast(*logfst, &fst);
149  VectorFst<StdArc> fst_copy(fst);
150  RemoveEpsLocalSpecial(&fst); // removes eps in std-arc but keep stochastic in log-arc
151  // make sure equivalent.
152  assert(RandEquivalent(fst, fst_copy, 5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 100/*path length-- max?*/));
153  VectorFst<LogArc> logfst2;
154  Cast(fst, &logfst2);
155 
156  {
157  std::cout << "logfst2 = \n";
158  FstPrinter<LogArc> fstprinter(logfst2, NULL, NULL, NULL, false, true, "\t");
159  fstprinter.Print(&std::cout, "standard output");
160  }
161  if (ApproxEqual(ShortestDistance(*logfst), ShortestDistance(logfst2))) {
162  // make sure we preserved stochasticity in cases where doing so was
163  // possible... if the log-semiring total weight changed, then it is
164  // not possible so don't assert this.
165  assert(IsStochasticFst(logfst2, kDelta*10));
166  }
167  delete logfst;
168 }
169 
170 } // namespace fst
171 
172 int main() {
173  using namespace fst;
174  for (int i = 0; i < 10; i++) {
175  TestRemoveEpsLocal<fst::StdArc>();
177  }
178 }
fst::StdArc::StateId StateId
int main()
LatticeWeightTpl< FloatType > Divide(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2, DivideType typ=DIVIDE_ANY)
float RandUniform(struct RandomState *state=NULL)
Returns a random number strictly between 0 and 1.
Definition: kaldi-math.h:151
void RemoveEpsLocal(MutableFst< Arc > *fst)
RemoveEpsLocal removes some (but not necessarily all) epsilons in an FST, using an algorithm that is g...
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
LatticeWeightTpl< FloatType > Plus(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
bool ApproxEqual(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2, float delta=kDelta)
void RemoveEpsLocalSpecial(MutableFst< StdArc > *fst)
As RemoveEpsLocal but takes care to preserve stochasticity when cast to LogArc.
static void TestRemoveEpsLocalSpecial()
bool IsStochasticFst(const Fst< LogArc > &fst, float delta, LogArc::Weight *min_sum, LogArc::Weight *max_sum)
fst::StdArc::Label Label
int Rand(struct RandomState *state)
Definition: kaldi-math.cc:45
fst::StdArc::Weight Weight
static void TestRemoveEpsLocal()