determinize-star-test.cc
Go to the documentation of this file.
1 // fstext/determinize-star-test.cc
2 
3 // Copyright 2009-2011 Microsoft Corporation
4 // 2015 Hainan Xu
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include "base/kaldi-math.h"
22 #include "fstext/pre-determinize.h"
25 #include "fstext/fst-test-utils.h"
26 
27 
28 namespace fst
29 {
30 
31 // test that determinization proceeds correctly on general
32 // FSTs (not guaranteed determinzable, but we use the
33 // max-states option to stop it getting out of control).
34 template<class Arc> void TestDeterminizeGeneral() {
35  int max_states = 100; // don't allow more det-states than this.
36  for(int i = 0; i < 100; i++) {
37  VectorFst<Arc> *fst = RandFst<Arc>();
38  std::cout << "FST before determinizing is:\n";
39  {
40  FstPrinter<Arc> fstprinter(*fst, NULL, NULL, NULL, false, true, "\t");
41  fstprinter.Print(&std::cout, "standard output");
42  }
43  VectorFst<Arc> ofst;
44  try {
45  DeterminizeStar<Fst<Arc> >(*fst, &ofst, kDelta, NULL, max_states);
46  std::cout << "FST after determinizing is:\n";
47  {
48  FstPrinter<Arc> fstprinter(ofst, NULL, NULL, NULL, false, true, "\t");
49  fstprinter.Print(&std::cout, "standard output");
50  }
51  assert(RandEquivalent(*fst, ofst, 5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 100/*path length, max*/));
52  } catch (...) {
53  std::cout << "Failed to determinize *this FST (probably not determinizable)\n";
54  }
55  delete fst;
56  }
57 }
58 
59 
60 // Don't instantiate with log semiring, as RandEquivalent may fail.
61 template<class Arc> void TestDeterminize() {
62  typedef typename Arc::Label Label;
63  typedef typename Arc::StateId StateId;
64  typedef typename Arc::Weight Weight;
65 
66  VectorFst<Arc> *fst = new VectorFst<Arc>();
67  int n_syms = 2 + kaldi::Rand() % 5, n_states = 3 + kaldi::Rand() % 10, n_arcs = 5 + kaldi::Rand() % 30, n_final = 1 + kaldi::Rand()%3; // Up to 2 unique symbols.
68  std::cout << "Testing pre-determinize with "<<n_syms<<" symbols, "<<n_states<<" states and "<<n_arcs<<" arcs and "<<n_final<<" final states.\n";
69  SymbolTable *sptr = NULL;
70 
71  std::vector<Label> all_syms; // including epsilon.
72  // Put symbols in the symbol table from 1..n_syms-1.
73  for (size_t i = 0;i < (size_t)n_syms;i++)
74  all_syms.push_back(i);
75 
76  // Create states.
77  std::vector<StateId> all_states;
78  for (size_t i = 0;i < (size_t)n_states;i++) {
79  StateId this_state = fst->AddState();
80  if (i == 0) fst->SetStart(i);
81  all_states.push_back(this_state);
82  }
83  // Set final states.
84  for (size_t j = 0;j < (size_t)n_final;j++) {
85  StateId id = all_states[kaldi::Rand() % n_states];
86  Weight weight = (Weight)(0.33*(kaldi::Rand() % 5) );
87  printf("calling SetFinal with %d and %f\n", id, weight.Value());
88  fst->SetFinal(id, weight);
89  }
90  // Create arcs.
91  for (size_t i = 0;i < (size_t)n_arcs;i++) {
92  Arc a;
93  a.nextstate = all_states[kaldi::Rand() % n_states];
94  a.ilabel = all_syms[kaldi::Rand() % n_syms];
95  a.olabel = all_syms[kaldi::Rand() % n_syms]; // same input+output vocab.
96  a.weight = (Weight) (0.33*(kaldi::Rand() % 2));
97  StateId start_state = all_states[kaldi::Rand() % n_states];
98  fst->AddArc(start_state, a);
99  }
100 
101  std::cout <<" printing before trimming\n";
102  {
103  FstPrinter<Arc> fstprinter(*fst, sptr, sptr, NULL, false, true, "\t");
104  fstprinter.Print(&std::cout, "standard output");
105  }
106  // Trim resulting FST.
107  Connect(fst);
108 
109  std::cout <<" printing after trimming\n";
110  {
111  FstPrinter<Arc> fstprinter(*fst, sptr, sptr, NULL, false, true, "\t");
112  fstprinter.Print(&std::cout, "standard output");
113  }
114 
115  VectorFst<Arc> *fst_copy_orig = new VectorFst<Arc>(*fst);
116 
117  std::vector<Label> extra_syms;
118  if (fst->Start() != kNoStateId) { // "Connect" did not make it empty....
119  PreDeterminize(fst, 1000, &extra_syms);
120  }
121 
122  std::cout <<" printing after predeterminization\n";
123  {
124  FstPrinter<Arc> fstprinter(*fst, sptr, sptr, NULL, false, true, "\t");
125  fstprinter.Print(&std::cout, "standard output");
126  }
127 
128 
129  { // Remove epsilon. All default args.
130  bool connect = true;
131  Weight weight_threshold = Weight::Zero();
132  int64 nstate = -1; // Relates to pruning.
133  double delta = kDelta; // I think a small weight value. Relates to some kind of pruning,
134  // I guess. But with no epsilon cycles, probably doensn't matter.
135  RmEpsilon(fst, connect, weight_threshold, nstate, delta);
136  }
137 
138  std::cout <<" printing after epsilon removal\n";
139  {
140  FstPrinter<Arc> fstprinter(*fst, sptr, sptr, NULL, false, true, "\t");
141  fstprinter.Print(&std::cout, "standard output");
142  }
143  VectorFst<Arc> ofst_orig;
144  VectorFst<Arc> ofst_star;
145 
146  {
147  printf("Determinizing with baseline\n");
148  DeterminizeOptions<Arc> opts; // Default options.
149  Determinize(*fst, &ofst_orig, opts);
150  }
151 
152  {
153  printf("Determinizing with DeterminizeStar\n");
154  DeterminizeStar(*fst, &ofst_star);
155  }
156 
157  {
158  std::cout <<" printing after determinization [baseline]\n";
159  FstPrinter<Arc> fstprinter(ofst_orig, sptr, sptr, NULL, false, true, "\t");
160  fstprinter.Print(&std::cout, "standard output");
161  assert(ofst_orig.Properties(kIDeterministic, true) == kIDeterministic);
162  }
163 
164  {
165  std::cout <<" printing after determinization [star]\n";
166  FstPrinter<Arc> fstprinter(ofst_star, sptr, sptr, NULL, false, true, "\t");
167  fstprinter.Print(&std::cout, "standard output");
168  assert(ofst_star.Properties(kIDeterministic, true) == kIDeterministic);
169  }
170 
171  assert(RandEquivalent(ofst_orig, ofst_star, 5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 100/*path length-- max?*/));
172 
173  int64 num_removed = DeleteISymbols(&ofst_star, extra_syms);
174  std::cout <<" printing after removing "<<num_removed<<" instances of extra symbols\n";
175  {
176  FstPrinter<Arc> fstprinter(ofst_star, sptr, sptr, NULL, false, true, "\t");
177  fstprinter.Print(&std::cout, "standard output");
178  }
179 
180  std::cout <<" Checking equivalent to original FST.\n";
181  // giving Rand() as a seed stops the random number generator from always being reset to
182  // the same point each time, while maintaining determinism of the test.
183  assert(RandEquivalent(ofst_star, *fst_copy_orig, 5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 100/*path length-- max?*/));
184 
185  delete fst;
186  delete fst_copy_orig;
187 }
188 
189 // Don't call this-- the test will fail due to the FST being non-functional.
190 template<class Arc> void TestDeterminize2() {
191  for(int i = 0; i < 10; i++) {
192  RandFstOptions opts;
193  opts.acyclic = true;
194  VectorFst<Arc> *ifst = RandFst<Arc>(opts);
195  VectorFst<Arc> ofst;
196  Determinize(*ifst, &ofst);
197  assert(RandEquivalent(*ifst, ofst, 5, 0.01, kaldi::Rand(), 100));
198  delete ifst;
199  }
200 }
201 
202 template<class Arc> void TestPush() {
203  typedef typename Arc::Label Label;
204  typedef typename Arc::StateId StateId;
205  typedef typename Arc::Weight Weight;
206 
207  VectorFst<Arc> *fst = new VectorFst<Arc>();
208  int n_syms = 2 + kaldi::Rand() % 5, n_states = 3 + kaldi::Rand() % 10, n_arcs = 5 + kaldi::Rand() % 30, n_final = 1 + kaldi::Rand()%3; // Up to 2 unique symbols.
209  std::cout << "Testing pre-determinize with "<<n_syms<<" symbols, "<<n_states<<" states and "<<n_arcs<<" arcs and "<<n_final<<" final states.\n";
210  SymbolTable *sptr = NULL;
211 
212  std::vector<Label> all_syms; // including epsilon.
213  // Put symbols in the symbol table from 1..n_syms-1.
214  for (size_t i = 0;i < (size_t)n_syms;i++)
215  all_syms.push_back(i);
216 
217  // Create states.
218  std::vector<StateId> all_states;
219  for (size_t i = 0;i < (size_t)n_states;i++) {
220  StateId this_state = fst->AddState();
221  if (i == 0) fst->SetStart(i);
222  all_states.push_back(this_state);
223  }
224  // Set final states.
225  for (size_t j = 0;j < (size_t)n_final;j++) {
226  StateId id = all_states[kaldi::Rand() % n_states];
227  Weight weight = (Weight)(0.33*(kaldi::Rand() % 5) );
228  printf("calling SetFinal with %d and %f\n", id, weight.Value());
229  fst->SetFinal(id, weight);
230  }
231  // Create arcs.
232  for (size_t i = 0;i < (size_t)n_arcs;i++) {
233  Arc a;
234  a.nextstate = all_states[kaldi::Rand() % n_states];
235  a.ilabel = all_syms[kaldi::Rand() % n_syms];
236  a.olabel = all_syms[kaldi::Rand() % n_syms]; // same input+output vocab.
237  a.weight = (Weight) (0.33*(kaldi::Rand() % 2));
238  StateId start_state = all_states[kaldi::Rand() % n_states];
239  fst->AddArc(start_state, a);
240  }
241 
242  std::cout <<" printing before trimming\n";
243  {
244  FstPrinter<Arc> fstprinter(*fst, sptr, sptr, NULL, false, true, "\t");
245  fstprinter.Print(&std::cout, "standard output");
246  }
247  // Trim resulting FST.
248  Connect(fst);
249 
250  std::cout <<" printing after trimming\n";
251  {
252  FstPrinter<Arc> fstprinter(*fst, sptr, sptr, NULL, false, true, "\t");
253  fstprinter.Print(&std::cout, "standard output");
254  }
255 
256  VectorFst<Arc> *fst_copy_orig = new VectorFst<Arc>(*fst);
257 
258  std::vector<Label> extra_syms;
259  if (fst->Start() != kNoStateId) { // "Connect" did not make it empty....
260  PreDeterminize(fst, 1000, &extra_syms);
261  }
262 
263  VectorFst<Arc> fst_pushed;
264  std::cout << "Pushing FST\n";
265  Push<Arc, REWEIGHT_TO_INITIAL>(*fst, &fst_pushed, kPushWeights|kPushLabels, kDelta);
266 
267  std::cout <<" printing after pushing\n";
268  {
269  FstPrinter<Arc> fstprinter(fst_pushed, sptr, sptr, NULL, false, true, "\t");
270  fstprinter.Print(&std::cout, "standard output");
271  }
272 
273  assert(RandEquivalent(*fst, fst_pushed, 5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 100/*path length-- max?*/));
274 
275  delete fst;
276  delete fst_copy_orig;
277 }
278 
279 // Don't instantiate with log semiring, as RandEquivalent may fail.
280 template<class Arc> void TestMinimize() {
281  typedef typename Arc::Label Label;
282  typedef typename Arc::StateId StateId;
283  typedef typename Arc::Weight Weight;
284 
285  VectorFst<Arc> *fst = new VectorFst<Arc>();
286  int n_syms = 2 + kaldi::Rand() % 5, n_states = 3 + kaldi::Rand() % 10, n_arcs = 5 + kaldi::Rand() % 30, n_final = 1 + kaldi::Rand()%3; // Up to 2 unique symbols.
287  std::cout << "Testing pre-determinize with "<<n_syms<<" symbols, "<<n_states<<" states and "<<n_arcs<<" arcs and "<<n_final<<" final states.\n";
288  SymbolTable *sptr =NULL;
289 
290  std::vector<Label> all_syms; // including epsilon.
291  // Put symbols in the symbol table from 1..n_syms-1.
292  for (size_t i = 0;i < (size_t)n_syms;i++)
293  all_syms.push_back(i);
294 
295  // Create states.
296  std::vector<StateId> all_states;
297  for (size_t i = 0;i < (size_t)n_states;i++) {
298  StateId this_state = fst->AddState();
299  if (i == 0) fst->SetStart(i);
300  all_states.push_back(this_state);
301  }
302  // Set final states.
303  for (size_t j = 0;j < (size_t)n_final;j++) {
304  StateId id = all_states[kaldi::Rand() % n_states];
305  Weight weight = (Weight)(0.33*(kaldi::Rand() % 5) );
306  printf("calling SetFinal with %d and %f\n", id, weight.Value());
307  fst->SetFinal(id, weight);
308  }
309  // Create arcs.
310  for (size_t i = 0;i < (size_t)n_arcs;i++) {
311  Arc a;
312  a.nextstate = all_states[kaldi::Rand() % n_states];
313  a.ilabel = all_syms[kaldi::Rand() % n_syms];
314  a.olabel = all_syms[kaldi::Rand() % n_syms]; // same input+output vocab.
315  a.weight = (Weight) (0.33*(kaldi::Rand() % 2));
316  StateId start_state = all_states[kaldi::Rand() % n_states];
317  fst->AddArc(start_state, a);
318  }
319 
320  std::cout <<" printing before trimming\n";
321  {
322  FstPrinter<Arc> fstprinter(*fst, sptr, sptr, NULL, false, true, "\t");
323  fstprinter.Print(&std::cout, "standard output");
324  }
325  // Trim resulting FST.
326  Connect(fst);
327 
328  std::cout <<" printing after trimming\n";
329  {
330  FstPrinter<Arc> fstprinter(*fst, sptr, sptr, NULL, false, true, "\t");
331  fstprinter.Print(&std::cout, "standard output");
332  }
333 
334  VectorFst<Arc> *fst_copy_orig = new VectorFst<Arc>(*fst);
335 
336  std::vector<Label> extra_syms;
337  if (fst->Start() != kNoStateId) { // "Connect" did not make it empty....
338  PreDeterminize(fst, 1000, &extra_syms);
339  }
340 
341  std::cout <<" printing after predeterminization\n";
342  {
343  FstPrinter<Arc> fstprinter(*fst, sptr, sptr, NULL, false, true, "\t");
344  fstprinter.Print(&std::cout, "standard output");
345  }
346 
347 
348  { // Remove epsilon. All default args.
349  bool connect = true;
350  Weight weight_threshold = Weight::Zero();
351  int64 nstate = -1; // Relates to pruning.
352  double delta = kDelta; // I think a small weight value. Relates to some kind of pruning,
353  // I guess. But with no epsilon cycles, probably doensn't matter.
354  RmEpsilon(fst, connect, weight_threshold, nstate, delta);
355  }
356 
357  std::cout <<" printing after epsilon removal\n";
358  {
359  FstPrinter<Arc> fstprinter(*fst, sptr, sptr, NULL, false, true, "\t");
360  fstprinter.Print(&std::cout, "standard output");
361  }
362  VectorFst<Arc> ofst_orig;
363  VectorFst<Arc> ofst_star;
364 
365  {
366  printf("Determinizing with baseline\n");
367  DeterminizeOptions<Arc> opts; // Default options.
368  Determinize(*fst, &ofst_orig, opts);
369  }
370  {
371  std::cout <<" printing after determinization [baseline]\n";
372  FstPrinter<Arc> fstprinter(ofst_orig, sptr, sptr, NULL, false, true, "\t");
373  fstprinter.Print(&std::cout, "standard output");
374  }
375 
376 
377  {
378  printf("Determinizing with DeterminizeStar to Gallic semiring\n");
379  VectorFst<GallicArc<Arc> > gallic_fst;
380 
381  DeterminizeStar(*fst, &gallic_fst);
382  {
383  std::cout <<" printing after determinization by DeterminizeStar [in gallic]\n";
384  FstPrinter<GallicArc< Arc> > fstprinter(gallic_fst, sptr, sptr, NULL, false, true, "\t");
385  fstprinter.Print(&std::cout, "standard output");
386  }
387 
388 
389  printf("Pushing weights\n");
390  Push(&gallic_fst, REWEIGHT_TO_INITIAL, kDelta);
391 
392  {
393  std::cout <<" printing after pushing weights [in gallic]\n";
394  FstPrinter<GallicArc< Arc> > fstprinter(gallic_fst, sptr, sptr, NULL, false, true, "\t");
395  fstprinter.Print(&std::cout, "standard output");
396  }
397 
398 
399  printf("Minimizing [in Gallic]\n");
400  Minimize(&gallic_fst);
401  {
402  std::cout <<" printing after minimization [in gallic]\n";
403  FstPrinter<GallicArc< Arc> > fstprinter(gallic_fst, sptr, sptr, NULL, false, true, "\t");
404  fstprinter.Print(&std::cout, "standard output");
405  }
406 
407  printf("Converting gallic back to regular [my approach]\n");
409  typename Arc::Weight, GALLIC_LEFT> > fwfst(gallic_fst);
410  {
411  std::cout <<" printing factor-weight FST\n";
412  FstPrinter<GallicArc< Arc> > fstprinter(fwfst, sptr, sptr, NULL, false, true, "\t");
413  fstprinter.Print(&std::cout, "standard output");
414  }
415 
416  Map(fwfst, &ofst_star, FromGallicMapper<Arc, GALLIC_LEFT>());
417 
418  {
419  std::cout <<" printing after converting back to regular FST\n";
420  FstPrinter<Arc> fstprinter(ofst_star, sptr, sptr, NULL, false, true, "\t");
421  fstprinter.Print(&std::cout, "standard output");
422  }
423 
424  }
425 
426 
427  assert(RandEquivalent(ofst_orig, ofst_star, 5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 100/*path length-- max?*/));
428 
429 
430  int64 num_removed = DeleteISymbols(&ofst_star, extra_syms);
431  std::cout <<" printing after removing "<<num_removed<<" instances of extra symbols\n";
432  {
433  FstPrinter<Arc> fstprinter(ofst_star, sptr, sptr, NULL, false, true, "\t");
434  fstprinter.Print(&std::cout, "standard output");
435  }
436 
437  std::cout <<" Checking equivalent to original FST.\n";
438  // giving Rand() as a seed stops the random number generator from always being reset to
439  // the same point each time, while maintaining determinism of the test.
440  assert(RandEquivalent(ofst_star, *fst_copy_orig, 5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 100/*path length-- max?*/));
441 
442  delete fst;
443  delete fst_copy_orig;
444 }
445 
446 
447 template<class Arc, class inttype> void TestStringRepository() {
448  typedef typename Arc::Label Label;
449 
451 
452  int N = 100;
453  if (sizeof(inttype) == 1) N = 64;
454  std::vector<std::vector<Label> > strings(N);
455  std::vector<inttype> ids(N);
456 
457  for (int i = 0;i < N;i++) {
458  size_t len = kaldi::Rand() % 4;
459  std::vector<Label> vec;
460  for (size_t j = 0;j < len;j++) vec.push_back( (kaldi::Rand()%10) + 150*(kaldi::Rand()%2)); // make it have reasonable range.
461  if (i < 500 && vec.size() == 0) ids[i] = sr.IdOfEmpty();
462  else if (i < 500 && vec.size() == 1) ids[i] = sr.IdOfLabel(vec[0]);
463  else ids[i] = sr.IdOfSeq(vec);
464 
465  strings[i] = vec;
466  }
467 
468  for (int i = 0;i < N;i++) {
469  std::vector<Label> tmpv;
470  tmpv.push_back(10); // just put in garbage.
471  sr.SeqOfId(ids[i], &tmpv);
472  assert(tmpv == strings[i]);
473  assert(sr.IdOfSeq(strings[i]) == ids[i]);
474  if (strings[i].size() == 0) assert(ids[i] == sr.IdOfEmpty());
475  if (strings[i].size() == 1) assert(ids[i] == sr.IdOfLabel(strings[i][0]));
476 
477  if (sizeof(inttype) != 1) {
478  size_t prefix_len = kaldi::Rand() % (strings[i].size() + 1);
479  inttype s2 = sr.RemovePrefix(ids[i], prefix_len);
480  std::vector<Label> vec2;
481  sr.SeqOfId(s2, &vec2);
482  for (size_t j = 0;j < strings[i].size()-prefix_len;j++) {
483  assert(vec2[j] == strings[i][j+prefix_len]);
484  }
485  }
486 
487  }
488 }
489 
490 
491 } // end namespace fst
492 
493 
494 int main() {
495  for (int i = 0;i < 3;i++) { // We would need more iterations to check
496  // this properly.
497  fst::TestStringRepository<fst::StdArc, int>();
498  fst::TestStringRepository<fst::StdArc, unsigned int>();
499  // Not for use with char, but this helps reveal some kinds of bugs.
500  fst::TestStringRepository<fst::StdArc, unsigned char>();
501  fst::TestStringRepository<fst::StdArc, char>();
502  fst::TestDeterminizeGeneral<fst::StdArc>();
503  fst::TestDeterminize<fst::StdArc>();
504  // fst::TestDeterminize2<fst::StdArc>();
505  fst::TestPush<fst::StdArc>();
506  fst::TestMinimize<fst::StdArc>();
507  }
508 }
fst::StdArc::StateId StateId
TrivialFactorWeightFst takes as template parameter a FactorIterator as defined above.
void PreDeterminize(MutableFst< Arc > *fst, typename Arc::Label first_new_sym, std::vector< Int > *symsOut)
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
void TestDeterminize()
StringId IdOfSeq(const std::vector< Label > &v)
void TestDeterminizeGeneral()
StringId RemovePrefix(StringId id, size_t prefix_len)
void SeqOfId(StringId id, std::vector< Label > *v)
StringId IdOfLabel(Label l)
fst::StdArc::Label Label
int Rand(struct RandomState *state)
Definition: kaldi-math.cc:45
fst::StdArc::Weight Weight
void TestDeterminize2()
void TestMinimize()
void TestStringRepository()
int main()
int64 DeleteISymbols(MutableFst< Arc > *fst, std::vector< typename Arc::Label > isyms)
bool DeterminizeStar(F &ifst, MutableFst< typename F::Arc > *ofst, float delta, bool *debug_ptr, int max_states, bool allow_partial)
This function implements the normal version of DeterminizeStar, in which the output strings are repre...