context-fst-test.cc
Go to the documentation of this file.
1 // fstext/context-fst-test.cc
2 
3 // Copyright 2009-2011 Microsoft Corporation
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include "fstext/context-fst.h"
21 #include "fstext/fst-test-utils.h"
22 #include "tree/context-dep.h"
23 #include "util/kaldi-io.h"
24 #include "base/kaldi-math.h"
25 
26 namespace fst
27 {
28 using std::vector;
29 using std::cout;
30 
31 // GenAcceptorFromSequence generates a linear acceptor (identical input+output symbols) that has this
32 // sequence of symbols, and
33 template<class Arc>
34 static VectorFst<Arc> *GenAcceptorFromSequence(const vector<typename Arc::Label> &symbols, float cost) {
35  typedef typename Arc::Weight Weight;
36  typedef typename Arc::StateId StateId;
37 
38  vector<float> split_cost(symbols.size()+1, 0.0); // for #-arcs + end-state.
39  { // compute split_cost. it must sum to "cost".
40  std::set<int32> indices;
41  size_t num_indices = 1 + (kaldi::Rand() % split_cost.size());
42  while (indices.size() < num_indices) indices.insert(kaldi::Rand() % split_cost.size());
43  for (std::set<int32>::iterator iter = indices.begin(); iter != indices.end(); ++iter) {
44  split_cost[*iter] = cost / num_indices;
45  }
46  }
47 
48  VectorFst<Arc> *fst = new VectorFst<Arc>();
49  StateId cur_state = fst->AddState();
50  fst->SetStart(cur_state);
51  for (size_t i = 0; i < symbols.size(); i++) {
52  StateId next_state = fst->AddState();
53  Arc arc;
54  arc.ilabel = symbols[i];
55  arc.olabel = symbols[i];
56  arc.nextstate = next_state;
57  arc.weight = (Weight) split_cost[i];
58  fst->AddArc(cur_state, arc);
59  cur_state = next_state;
60 
61  }
62  fst->SetFinal(cur_state, (Weight)split_cost[symbols.size()]);
63  return fst;
64 }
65 
66 
67 
68 // CheckPhones is used to test the correctness of an FST that is the result of
69 // composition with a ContextFst.
70 template<class Arc>
71 static float CheckPhones(const VectorFst<Arc> &linear_fst,
72  const vector<typename Arc::Label> &phone_ids,
73  const vector<typename Arc::Label> &disambig_ids,
74  const vector<typename Arc::Label> &phone_seq,
75  const vector<vector<typename Arc::Label> > &ilabel_info,
76  int N, int P) {
77  typedef typename Arc::Label Label;
78  typedef typename Arc::StateId StateId;
79  typedef typename Arc::Weight Weight;
80 
81  assert(kaldi::IsSorted(phone_ids)); // so we can do binary_search.
82 
83 
84  vector<int32> input_syms;
85  vector<int32> output_syms;
86  Weight tot_cost;
87  bool ans = GetLinearSymbolSequence(linear_fst, &input_syms,
88  &output_syms, &tot_cost);
89  assert(ans); // should be linear.
90 
91  vector<int32> phone_seq_check;
92  for (size_t i = 0; i < output_syms.size(); i++)
93  if (std::binary_search(phone_ids.begin(), phone_ids.end(), output_syms[i]))
94  phone_seq_check.push_back(output_syms[i]);
95 
96  assert(phone_seq_check == phone_seq);
97 
98  vector<vector<int32> > input_syms_long;
99  for (size_t i = 0; i < input_syms.size(); i++) {
100  Label isym = input_syms[i];
101  if (ilabel_info[isym].size() == 0) continue; // epsilon.
102  if ( (ilabel_info[isym].size() == 1 &&
103  ilabel_info[isym][0] <= 0) ) continue; // disambig.
104  input_syms_long.push_back(ilabel_info[isym]);
105  }
106 
107  for (size_t i = 0; i < input_syms_long.size(); i++) {
108  vector<int32> phone_context_window(N); // phone at pos i will be at pos P in this window.
109  int pos = ((int)i) - P; // pos of first phone in window [ may be out of range] .
110  for (int j = 0; j < N; j++, pos++) {
111  if (static_cast<size_t>(pos) < phone_seq.size()) phone_context_window[j] = phone_seq[pos];
112  else phone_context_window[j] = 0; // 0 is a special symbol that context-dep-itf expects to see
113  // when no phone is present due to out-of-window. context-fst knows about this too.
114  }
115  assert(input_syms_long[i] == phone_context_window);
116  }
117  return tot_cost.Value();
118 }
119 
120 
121 
122 
123 template<class Arc>
124 static VectorFst<Arc> *GenRandPhoneSeq(vector<typename Arc::Label> &phone_syms,
125  vector<typename Arc::Label> &disambig_syms,
126  typename Arc::Label subsequential_symbol,
127  int num_subseq_syms,
128  float seq_prob,
129  vector<typename Arc::Label> *phoneseq_out) {
130  KALDI_ASSERT(phoneseq_out != NULL);
131  typedef typename Arc::Label Label;
132  // Generate an FST that is a random phone sequence, ending
133  // with "num_subseq_syms" subsequential symbols. It will
134  // have disambiguation symbols randomly interspersed throughout.
135  // The number of phones is random (possibly zero).
136  size_t len = (kaldi::Rand() % 4) * (kaldi::Rand() % 3); // up to 3*2=6 phones.
137  float disambig_prob = 0.33;
138  phoneseq_out->clear();
139  vector<Label> syms; // the phones
140  for (size_t i = 0; i < len; i++) {
141  while (kaldi::RandUniform() < disambig_prob) syms.push_back(disambig_syms[kaldi::Rand() % disambig_syms.size()]);
142  Label phone_id = phone_syms[kaldi::Rand() % phone_syms.size()];
143  phoneseq_out->push_back(phone_id); // record in output the underlying phone sequence.
144  syms.push_back(phone_id);
145  }
146  for (size_t i = 0; static_cast<int32>(i) < num_subseq_syms; i++) {
147  while (kaldi::RandUniform() < disambig_prob) syms.push_back(disambig_syms[kaldi::Rand() % disambig_syms.size()]);
148  syms.push_back(subsequential_symbol);
149  }
150  while (kaldi::RandUniform() < disambig_prob) syms.push_back(disambig_syms[kaldi::Rand() % disambig_syms.size()]);
151 
152  // OK, now have the symbols of the FST as a vector.
153  return GenAcceptorFromSequence<Arc>(syms, seq_prob);
154 }
155 
156 // Don't instantiate with log semiring, as RandEquivalent may fail.
157 // TestContestFst also test ReadILabelInfo and WriteILabelInfo.
158 static void TestContextFst(bool verbose, bool use_matcher) {
159  typedef StdArc Arc;
160  typedef Arc::Label Label;
161  typedef Arc::StateId StateId;
162  typedef Arc::Weight Weight;
163 
164  // Generate a random set of phones.
165  size_t num_phones = 1 + kaldi::Rand() % 10;
166  std::set<int32> phones_set;
167  while (phones_set.size() < num_phones) phones_set.insert(1 + kaldi::Rand() % (num_phones + 5)); // don't use 0 [== epsilon]
168  vector<int32> phones;
169  kaldi::CopySetToVector(phones_set, &phones);
170 
171  int N = 1 + kaldi::Rand() % 4; // Context size, in range 1..4.
172  int P = kaldi::Rand() % N; // 1.. N-1.
173  if (verbose) std::cout << "N = "<< N << ", P = "<<P<<'\n';
174 
175  Label subsequential_symbol = 1000;
176  vector<int32> disambig_syms;
177  for (size_t i =0; i < 5; i++) disambig_syms.push_back(500 + i);
178  vector<int32> phone_syms;
179  for (size_t i = 0; i < phones.size();i++) phone_syms.push_back(phones[i]);
180 
181 
182  InverseContextFst inv_cfst(subsequential_symbol,
183  phones, disambig_syms,
184  N, P);
185 
186 
187  /* Now create random phone-sequences and compose them with the context FST.
188  */
189 
190  for (size_t p = 0; p < 10; p++) {
191  vector<int32> phone_seq;
192  int num_subseq = N - P - 1; // zero if P == N-1, i.e. P is last element, i.e. left-context only.
193  float tot_cost = 20.0 * kaldi::RandUniform();
194  VectorFst<Arc> *f = GenRandPhoneSeq<Arc>(phone_syms, disambig_syms, subsequential_symbol, num_subseq, tot_cost, &phone_seq);
195  if (verbose) {
196  std::cout << "Sequence FST is:\n";
197  { // Try to print the fst.
198  FstPrinter<Arc> fstprinter(*f, NULL, NULL, NULL, false, true, "\t");
199  fstprinter.Print(&std::cout, "standard output");
200  }
201  }
202 
203  VectorFst<Arc> fst_composed;
204 
205  ComposeDeterministicOnDemandInverse(*f, &inv_cfst, &fst_composed);
206 
207 
208  // Testing WriteILabelInfo and ReadILabelInfo.
209  {
210  bool binary = (kaldi::Rand() % 2 == 0);
211  WriteILabelInfo(kaldi::Output("tmpf", binary).Stream(),
212  binary, inv_cfst.IlabelInfo());
213 
214  bool binary_in;
215  vector<vector<int32> > ilabel_info;
216  kaldi::Input ki("tmpf", &binary_in);
217  ReadILabelInfo(ki.Stream(),
218  binary_in, &ilabel_info);
219  assert(ilabel_info == inv_cfst.IlabelInfo());
220  }
221 
222 
223  if (verbose) {
224  std::cout << "Composed FST is:\n";
225  { // Try to print the fst.
226  FstPrinter<Arc> fstprinter(fst_composed, NULL, NULL, NULL, false, true, "\t");
227  fstprinter.Print(&std::cout, "standard output");
228  }
229  }
230 
231  // now check the composed FST.
232  float tot_cost_check = CheckPhones<Arc>(fst_composed,
233  phone_syms,
234  disambig_syms,
235  phone_seq,
236  inv_cfst.IlabelInfo(),
237  N, P);
238  kaldi::AssertEqual(tot_cost, tot_cost_check);
239 
240  delete f;
241  }
242 
243  unlink("tmpf");
244 }
245 
246 
247 } // namespace fst
248 
249 int main() {
250 
251  for (int i = 0;i < 16;i++) {
252  bool verbose = (i < 4);
253  bool use_matcher = ( (i/4) % 2 == 0);
254  fst::TestContextFst(verbose, use_matcher);
255  }
256 }
fst::StdArc::StateId StateId
const std::vector< std::vector< int32 > > & IlabelInfo() const
Definition: context-fst.h:194
void WriteILabelInfo(std::ostream &os, bool binary, const vector< vector< int32 > > &info)
Utility function for writing ilabel-info vectors to disk.
Definition: context-fst.cc:325
void CopySetToVector(const std::set< T > &s, std::vector< T > *v)
Copies the elements of a set to a vector.
Definition: stl-utils.h:86
float RandUniform(struct RandomState *state=NULL)
Returns a random number strictly between 0 and 1.
Definition: kaldi-math.h:151
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
fst::StdArc StdArc
kaldi::int32 int32
int main()
static VectorFst< Arc > * GenAcceptorFromSequence(const vector< typename Arc::Label > &symbols, float cost)
bool GetLinearSymbolSequence(const Fst< Arc > &fst, std::vector< I > *isymbols_out, std::vector< I > *osymbols_out, typename Arc::Weight *tot_weight_out)
GetLinearSymbolSequence gets the symbol sequence from a linear FST.
void ComposeDeterministicOnDemandInverse(const Fst< Arc > &right, DeterministicOnDemandFst< Arc > *left, MutableFst< Arc > *fst_composed)
This function does '*fst_composed = Compose(Inverse(*fst2), fst1)' Note that the arguments are revers...
void ReadILabelInfo(std::istream &is, bool binary, vector< vector< int32 > > *info)
Utility function for reading ilabel-info vectors from disk.
Definition: context-fst.cc:335
std::istream & Stream()
Definition: kaldi-io.cc:826
fst::StdArc::Label Label
int Rand(struct RandomState *state)
Definition: kaldi-math.cc:45
fst::StdArc::Weight Weight
static void TestContextFst(bool verbose, bool use_matcher)
static float CheckPhones(const VectorFst< Arc > &linear_fst, const vector< typename Arc::Label > &phone_ids, const vector< typename Arc::Label > &disambig_ids, const vector< typename Arc::Label > &phone_seq, const vector< vector< typename Arc::Label > > &ilabel_info, int N, int P)
bool IsSorted(const std::vector< T > &vec)
Returns true if the vector is sorted.
Definition: stl-utils.h:47
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
static void AssertEqual(float a, float b, float relative_tolerance=0.001)
assert abs(a - b) <= relative_tolerance * (abs(a)+abs(b))
Definition: kaldi-math.h:276
static VectorFst< Arc > * GenRandPhoneSeq(vector< typename Arc::Label > &phone_syms, vector< typename Arc::Label > &disambig_syms, typename Arc::Label subsequential_symbol, int num_subseq_syms, float seq_prob, vector< typename Arc::Label > *phoneseq_out)