factor-test.cc
Go to the documentation of this file.
1 // fstext/factor-test.cc
2 
3 // Copyright 2009-2011 Microsoft Corporation
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
21 #include "fstext/factor.h"
22 #include "fstext/fstext-utils.h"
23 #include "fstext/fst-test-utils.h"
24 #include "base/kaldi-math.h"
25 
26 
27 namespace fst
28 {
29 using std::vector;
30 
31 // Don't instantiate with log semiring, as RandEquivalent may fail.
32 template<class Arc> static void TestFactor() {
33  typedef typename Arc::Label Label;
34  typedef typename Arc::StateId StateId;
35  typedef typename Arc::Weight Weight;
36 
37  VectorFst<Arc> fst;
38  int n_syms = 2 + kaldi::Rand() % 5, n_arcs = 5 + kaldi::Rand() % 30, n_final = 1 + kaldi::Rand()%10;
39 
40  SymbolTable symtab("my-symbol-table"), *sptr = &symtab;
41 
42  vector<Label> all_syms; // including epsilon.
43  // Put symbols in the symbol table from 1..n_syms-1.
44  for (size_t i = 0;i < (size_t)n_syms;i++) {
45  std::stringstream ss;
46  if (i == 0) ss << "<eps>";
47  else ss<<i;
48  Label cur_lab = sptr->AddSymbol(ss.str());
49  assert(cur_lab == (Label)i);
50  all_syms.push_back(cur_lab);
51  }
52  assert(all_syms[0] == 0);
53 
54  fst.AddState();
55  int cur_num_states = 1;
56  for (int i = 0; i < n_arcs; i++) {
57  StateId src_state = kaldi::Rand() % cur_num_states;
58  StateId dst_state;
59  if (kaldi::RandUniform() < 0.1) dst_state = kaldi::Rand() % cur_num_states;
60  else {
61  dst_state = cur_num_states++; fst.AddState();
62  }
63  Arc arc;
64  if (kaldi::RandUniform() < 0.5) arc.ilabel = all_syms[kaldi::Rand()%all_syms.size()];
65  else arc.ilabel = 0;
66  if (kaldi::RandUniform() < 0.5) arc.olabel = all_syms[kaldi::Rand()%all_syms.size()];
67  else arc.olabel = 0;
68  arc.weight = (Weight) (0 + 0.1*(kaldi::Rand() % 5));
69  arc.nextstate = dst_state;
70  fst.AddArc(src_state, arc);
71  }
72  for (int i = 0; i < n_final; i++) {
73  fst.SetFinal(kaldi::Rand() % cur_num_states, (Weight) (0 + 0.1*(kaldi::Rand() % 5)));
74  }
75 
76  if (kaldi::RandUniform() < 0.8) fst.SetStart(0); // usually leads to nicer examples.
77  else fst.SetStart(kaldi::Rand() % cur_num_states);
78 
79  std::cout <<" printing before trimming\n";
80  {
81  FstPrinter<Arc> fstprinter(fst, sptr, sptr, NULL, false, true, "\t");
82  fstprinter.Print(&std::cout, "standard output");
83  }
84  // Trim resulting FST.
85  Connect(&fst);
86 
87  std::cout <<" printing after trimming\n";
88  {
89  FstPrinter<Arc> fstprinter(fst, sptr, sptr, NULL, false, true, "\t");
90  fstprinter.Print(&std::cout, "standard output");
91  }
92 
93  if (fst.Start() == kNoStateId) return; // "Connect" made it empty.
94 
95  VectorFst<Arc> fst_pushed;
96  Push<Arc, REWEIGHT_TO_INITIAL>(fst, &fst_pushed, kPushLabels);
97 
98  VectorFst<Arc> fst_factored;
99  vector<vector<typename Arc::Label> > symbols;
100 
101  Factor(fst, &fst_factored, &symbols);
102 
103  // Check no epsilons in "symbols".
104  for (size_t i = 0; i < symbols.size(); i++)
105  assert(symbols[i].size() == 0 || *(std::min(symbols[i].begin(), symbols[i].end())) > 0);
106 
107  VectorFst<Arc> fst_factored_pushed;
108  vector<vector<typename Arc::Label> > symbols_pushed;
109  Factor(fst_pushed, &fst_factored_pushed, &symbols_pushed);
110 
111  std::cout << "Unfactored has "<<fst.NumStates()<<" states, factored has "<<fst_factored.NumStates()<<", and pushed+factored has "<<fst_factored_pushed.NumStates()<<'\n';
112 
113  assert(fst_factored.NumStates() <= fst.NumStates());
114  // assert(fst_factored_pushed.NumStates() <= fst_factored.NumStates()); // pushing should only help. [ no, it doesn't]
115  assert(fst_factored_pushed.NumStates() <= fst_pushed.NumStates());
116 
117  VectorFst<Arc> fst_factored_copy(fst_factored);
118 
119  VectorFst<Arc> fst_factored_unfactored(fst_factored);
120  ExpandInputSequences(symbols, &fst_factored_unfactored);
121 
122  VectorFst<Arc> factor_fst;
123  CreateFactorFst(symbols, &factor_fst);
124  VectorFst<Arc> fst_factored_unfactored2;
125  Compose(factor_fst, fst_factored, &fst_factored_unfactored2);
126 
127  ExpandInputSequences(symbols_pushed, &fst_factored_pushed);
128 
129  assert(RandEquivalent(fst, fst_factored_unfactored, 5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 100/*path length-- max?*/));
130 
131  assert(RandEquivalent(fst, fst_factored_unfactored2, 5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 100/*path length-- max?*/));
132 
133  assert(RandEquivalent(fst, fst_factored_pushed, 5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 100/*path length-- max?*/));
134 
135  { // Have tested for equivalence; now do another test: that FactorFst actually finds all
136  // the factors. Do this by inserting factors using ExpandInputSequences and making sure it gets
137  // rid of them all.
138  Label max_label = *(std::max_element(all_syms.begin(), all_syms.end()));
139  vector<vector<Label> > new_labels(max_label+1);
140  for (Label l = 1; l < static_cast<Label>(new_labels.size()); l++) {
141  int n = kaldi::Rand() % 5;
142  for (int i = 0; i < n; i++) new_labels[l].push_back(kaldi::Rand() % 100);
143  }
144  VectorFst<Arc> fst_expanded(fst);
145  ExpandInputSequences(new_labels, &fst_expanded);
146 
147  vector<vector<Label> > factors;
148  VectorFst<Arc> fst_reduced;
149  Factor(fst_expanded, &fst_reduced, &factors);
150  assert(fst_reduced.NumStates() <= fst.NumStates()); // Checking that it found all the factors.
151  }
152 
153  { // This block test MapInputSymbols [but relies on the correctness of Factor
154  // and ExpandInputSequences to do so].
155 
156  std::map<Label, Label> symbols_reverse_map; // from new->old.
157  symbols_reverse_map[0] = 0; // map eps to eps.
158  for (Label i = 1; i < static_cast<Label>(symbols.size()); i++) {
159  Label new_i;
160  do {
161  new_i = kaldi::Rand() % (symbols.size() + 20);
162  } while (symbols_reverse_map.count(new_i) == 1);
163  symbols_reverse_map[new_i] = i;
164  }
165  vector<vector<Label> > symbols_new;
166  vector<Label> symbol_map(symbols.size()); // from old->new.
167  typename std::map<Label, Label>::iterator iter = symbols_reverse_map.begin();
168  for (; iter != symbols_reverse_map.end(); iter++) {
169  Label new_label = iter->first, old_label = iter->second;
170  if (new_label >= static_cast<Label>(symbols_new.size())) symbols_new.resize(new_label+1);
171  symbols_new[new_label] = symbols[old_label];
172  symbol_map[old_label] = new_label;
173  }
174  MapInputSymbols(symbol_map, &fst_factored_copy);
175  ExpandInputSequences(symbols_new, &fst_factored_copy);
176  assert(RandEquivalent(fst, fst_factored_copy,
177  5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/,
178  100/*path length-- max?*/));
179  }
180 
181 }
182 
183 
184 } // namespace fst
185 
186 int main() {
187  using namespace fst;
188  for (int i = 0;i < 25;i++) {
189  TestFactor<fst::StdArc>();
190  }
191 }
fst::StdArc::StateId StateId
float RandUniform(struct RandomState *state=NULL)
Returns a random number strictly between 0 and 1.
Definition: kaldi-math.h:151
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
int main()
Definition: factor-test.cc:186
static void TestFactor()
Definition: factor-test.cc:32
struct rnnlm::@11::@12 n
void CreateFactorFst(const std::vector< std::vector< I > > &sequences, MutableFst< Arc > *fst)
The function CreateFactorFst will create an FST that expands out the "factors" that are the indices o...
Definition: factor-inl.h:250
void Factor(const Fst< Arc > &fst, MutableFst< Arc > *ofst, std::vector< std::vector< I > > *symbols_out)
Factor identifies linear chains of states with an olabel (if any) only on the first arc of the chain...
Definition: factor-inl.h:69
fst::StdArc::Label Label
int Rand(struct RandomState *state)
Definition: kaldi-math.cc:45
fst::StdArc::Weight Weight
void MapInputSymbols(const std::vector< I > &symbol_mapping, MutableFst< Arc > *fst)
void ExpandInputSequences(const std::vector< std::vector< I > > &sequences, MutableFst< Arc > *fst)
ExpandInputSequences expands out the input symbols into sequences of input symbols.
Definition: factor-inl.h:163