context-fst.cc
Go to the documentation of this file.
1 // fstext/context-fst.cc
2 
3 // Copyright 2018 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include "fstext/context-fst.h"
21 #include "base/kaldi-error.h"
22 
23 namespace fst {
24 using std::vector;
25 
26 
28  Label subsequential_symbol,
29  const vector<int32>& phones,
30  const vector<int32>& disambig_syms,
31  int32 context_width,
32  int32 central_position):
33  context_width_(context_width),
34  central_position_(central_position),
35  phone_syms_(phones),
36  disambig_syms_(disambig_syms),
37  subsequential_symbol_(subsequential_symbol) {
38 
39  { // This block checks the inputs.
40  KALDI_ASSERT(subsequential_symbol != 0
41  && disambig_syms_.count(subsequential_symbol) == 0
42  && phone_syms_.count(subsequential_symbol) == 0);
43  if (phone_syms_.empty())
44  KALDI_WARN << "Context FST created but there are no phone symbols: probably "
45  "input FST was empty.";
46  KALDI_ASSERT(phone_syms_.count(0) == 0 && disambig_syms_.count(0) == 0 &&
48  for (size_t i = 0; i < phones.size(); i++) {
49  KALDI_ASSERT(disambig_syms_.count(phones[i]) == 0);
50  }
51  }
52 
53  // empty vector, will be the ilabel_info vector that corresponds to epsilon,
54  // in case our FST needs to output epsilons.
55  vector<int32> empty_vec;
56  Label epsilon_label = FindLabel(empty_vec);
57 
58  // epsilon_vec is the phonetic context window we have at the very start of a
59  // sequence, meaning "no real phones have been seen yet".
60  vector<int32> epsilon_vec(context_width_ - 1, 0);
61  StateId start_state = FindState(epsilon_vec);
62 
63  KALDI_ASSERT(epsilon_label == 0 && start_state == 0);
64 
66  // We add a symbol whose sequence representation is [ 0 ], and whose
67  // symbol-id is 1. This is treated as a disambiguation symbol, we call it
68  // #-1 in printed form. It is necessary to ensure that all determinizable
69  // LG's will have determinizable CLG's. The problem it fixes is quite
70  // subtle-- it relates to reordering of disambiguation symbols (they appear
71  // earlier in CLG than in LG, relative to phones), and the fact that if a
72  // disambig symbol appears at the very start of a sequence in CLG, it's not
73  // clear exactly where it appeared on the corresponding sequence at the input
74  // of LG.
75  vector<int32> pseudo_eps_vec;
76  pseudo_eps_vec.push_back(0);
77  pseudo_eps_symbol_= FindLabel(pseudo_eps_vec);
79  } else {
80  pseudo_eps_symbol_ = 0; // use actual epsilon.
81  }
82 }
83 
84 
86  std::vector<int32> *phone_seq) {
87  if (!phone_seq->empty()) {
88  phone_seq->erase(phone_seq->begin());
89  phone_seq->push_back(label);
90  }
91 }
92 
94  const std::vector<int32> &seq, Label label,
95  std::vector<int32> *full_phone_sequence) {
96  int32 context_width = context_width_;
97  full_phone_sequence->reserve(context_width);
98  full_phone_sequence->insert(full_phone_sequence->end(),
99  seq.begin(), seq.end());
100  full_phone_sequence->push_back(label);
101  for (int32 i = central_position_ + 1; i < context_width; i++) {
102  if ((*full_phone_sequence)[i] == subsequential_symbol_) {
103  (*full_phone_sequence)[i] = 0;
104  }
105  }
106 }
107 
108 
110  KALDI_ASSERT(static_cast<size_t>(s) < state_seqs_.size());
111 
112  const vector<int32> &phone_context = state_seqs_[s];
113 
114  KALDI_ASSERT(phone_context.size() == context_width_ - 1);
115 
116  bool has_final_prob;
117 
118  if (central_position_ < context_width_ - 1) {
119  has_final_prob = (phone_context[central_position_] == subsequential_symbol_);
120  // if phone_context[central_position_] != subsequential_symbol_ then we have
121  // pending phones-in-context that we still need to output, so we need to
122  // consume more subsequential symbols before we can terminate.
123  } else {
124  has_final_prob = true;
125  }
126  return has_final_prob ? Weight::One() : Weight::Zero();
127 }
128 
130  KALDI_ASSERT(ilabel != 0 && static_cast<size_t>(s) < state_seqs_.size() &&
131  state_seqs_[s].size() == context_width_ - 1);
132 
133  if (IsDisambigSymbol(ilabel)) {
134  // A disambiguation-symbol self-loop arc.
135  CreateDisambigArc(s, ilabel, arc);
136  return true;
137  } else if (IsPhoneSymbol(ilabel)) {
138  const vector<int32> &seq = state_seqs_[s];
139  if (!seq.empty() && seq.back() == subsequential_symbol_) {
140  return false; // A real phone is not allowed to follow the subsequential
141  // symbol.
142  }
143 
144  // next_seq will be 'seq' shifted left by 1, with 'ilabel' appended.
145  vector<int32> next_seq(seq);
146  ShiftSequenceLeft(ilabel, &next_seq);
147 
148  // full-seq will be the full context window of size context_width_.
149  vector<int32> full_seq;
150  GetFullPhoneSequence(seq, ilabel, &full_seq);
151 
152  StateId next_s = FindState(next_seq);
153 
154  CreatePhoneOrEpsArc(s, next_s, ilabel, full_seq, arc);
155  return true;
156  } else if (ilabel == subsequential_symbol_) {
157  const vector<int32> &seq = state_seqs_[s];
158 
159  if (central_position_ + 1 == context_width_ ||
161  // We already had "enough" subsequential symbols in a row and don't want to
162  // accept any more, or we'd be making the subsequential symbol the central phone.
163  return false;
164  }
165 
166  // full-seq will be the full context window of size context_width_.
167  vector<int32> full_seq;
168  GetFullPhoneSequence(seq, ilabel, &full_seq);
169 
170  vector<int32> next_seq(seq);
171  ShiftSequenceLeft(ilabel, &next_seq);
172  StateId next_s = FindState(next_seq);
173 
174  CreatePhoneOrEpsArc(s, next_s, ilabel, full_seq, arc);
175  return true;
176  } else {
177  KALDI_ERR << "ContextFst: CreateArc, invalid ilabel supplied [confusion "
178  << "about phone list or disambig symbols?]: " << ilabel;
179  }
180  return false; // won't get here. suppress compiler error.
181 }
182 
183 
185  // Creates a self-loop arc corresponding to the disambiguation symbol.
186  vector<int32> label_info; // This will be a vector containing just [ -olabel ].
187  label_info.push_back(-ilabel); // olabel is a disambiguation symbol. Use its negative
188  // so we can more easily distinguish them from phones.
189  Label olabel = FindLabel(label_info);
190  arc->ilabel = ilabel;
191  arc->olabel = olabel;
192  arc->weight = Weight::One();
193  arc->nextstate = s; // self-loop.
194 }
195 
197  Label ilabel,
198  const vector<int32> &phone_seq,
199  Arc *arc) {
201 
202  arc->ilabel = ilabel;
203  arc->weight = Weight::One();
204  arc->nextstate = dest;
205  if (phone_seq[central_position_] == 0) {
206  // This can happen at the beginning of the graph. In this case we don't
207  // output a real phone, we create an epsilon arc (but sometimes we need to
208  // use a special disambiguation symbol instead of epsilon).
209  arc->olabel = pseudo_eps_symbol_;
210  } else {
211  // We have a phone in the central position.
212  arc->olabel = FindLabel(phone_seq);
213  }
214 }
215 
217  // Finds state-id corresponding to this vector of phones. Inserts it if
218  // necessary.
219  KALDI_ASSERT(static_cast<int32>(seq.size()) == context_width_ - 1);
220  VectorToStateMap::const_iterator iter = state_map_.find(seq);
221  if (iter == state_map_.end()) { // Not already in map.
222  StateId this_state_id = (StateId)state_seqs_.size();
223  state_seqs_.push_back(seq);
224  state_map_[seq] = this_state_id;
225  return this_state_id;
226  } else {
227  return iter->second;
228  }
229 }
230 
231 StdArc::Label InverseContextFst::FindLabel(const vector<int32> &label_vec) {
232  // Finds the ilabel corresponding to this vector (creates a new ilabel if
233  // necessary).
234  VectorToLabelMap::const_iterator iter = ilabel_map_.find(label_vec);
235  if (iter == ilabel_map_.end()) { // Not already in map.
236  Label this_label = ilabel_info_.size();
237  ilabel_info_.push_back(label_vec);
238  ilabel_map_[label_vec] = this_label;
239  return this_label;
240  } else {
241  return iter->second;
242  }
243 }
244 
245 
246 void ComposeContext(const vector<int32> &disambig_syms_in,
247  int32 context_width, int32 central_position,
248  VectorFst<StdArc> *ifst,
249  VectorFst<StdArc> *ofst,
250  vector<vector<int32> > *ilabels_out,
251  bool project_ifst) {
252  KALDI_ASSERT(ifst != NULL && ofst != NULL);
253  KALDI_ASSERT(context_width > 0);
254  KALDI_ASSERT(central_position >= 0);
255  KALDI_ASSERT(central_position < context_width);
256 
257  vector<int32> disambig_syms(disambig_syms_in);
258  std::sort(disambig_syms.begin(), disambig_syms.end());
259 
260  vector<int32> all_syms;
261  GetInputSymbols(*ifst, false/*no eps*/, &all_syms);
262  std::sort(all_syms.begin(), all_syms.end());
263  vector<int32> phones;
264  for (size_t i = 0; i < all_syms.size(); i++)
265  if (!std::binary_search(disambig_syms.begin(),
266  disambig_syms.end(), all_syms[i]))
267  phones.push_back(all_syms[i]);
268 
269  // Get subsequential symbol that does not clash with
270  // any disambiguation symbol or symbol in the FST.
271  int32 subseq_sym = 1;
272  if (!all_syms.empty())
273  subseq_sym = std::max(subseq_sym, all_syms.back() + 1);
274  if (!disambig_syms.empty())
275  subseq_sym = std::max(subseq_sym, disambig_syms.back() + 1);
276 
277  // if central_position == context_width-1, it's left-context, and no
278  // subsequential symbol is needed.
279  if (central_position != context_width-1) {
280  AddSubsequentialLoop(subseq_sym, ifst);
281  if (project_ifst) {
282  fst::Project(ifst, fst::PROJECT_INPUT);
283  }
284  }
285 
286  InverseContextFst inv_c(subseq_sym, phones, disambig_syms,
287  context_width, central_position);
288 
289  // The following statement is equivalent to the following
290  // (if FSTs had the '*' operator for composition):
291  // (*ofst) = inv(inv_c) * (*ifst)
292  ComposeDeterministicOnDemandInverse(*ifst, &inv_c, ofst);
293 
294  inv_c.SwapIlabelInfo(ilabels_out);
295 }
296 
298  MutableFst<StdArc> *fst) {
299  typedef StdArc Arc;
300  typedef typename Arc::StateId StateId;
301  typedef typename Arc::Weight Weight;
302 
303  vector<StateId> final_states;
304  for (StateIterator<MutableFst<Arc> > siter(*fst); !siter.Done(); siter.Next()) {
305  StateId s = siter.Value();
306  if (fst->Final(s) != Weight::Zero()) final_states.push_back(s);
307  }
308 
309  StateId superfinal = fst->AddState();
310  Arc arc(subseq_symbol, 0, Weight::One(), superfinal);
311  fst->AddArc(superfinal, arc); // loop at superfinal.
312  fst->SetFinal(superfinal, Weight::One());
313 
314  for (size_t i = 0; i < final_states.size(); i++) {
315  StateId s = final_states[i];
316  fst->AddArc(s, Arc(subseq_symbol, 0, fst->Final(s), superfinal));
317  // No, don't remove the final-weights of the original states..
318  // this is so we can add the subsequential loop in cases where
319  // there is no context, and it won't hurt.
320  // fst->SetFinal(s, Weight::Zero());
321  arc.nextstate = final_states[i];
322  }
323 }
324 
325 void WriteILabelInfo(std::ostream &os, bool binary,
326  const vector<vector<int32> > &info) {
327  int32 size = info.size();
328  kaldi::WriteBasicType(os, binary, size);
329  for (int32 i = 0; i < size; i++) {
330  kaldi::WriteIntegerVector(os, binary, info[i]);
331  }
332 }
333 
334 
335 void ReadILabelInfo(std::istream &is, bool binary,
336  vector<vector<int32> > *info) {
337  int32 size = info->size();
338  kaldi::ReadBasicType(is, binary, &size);
339  info->resize(size);
340  for (int32 i = 0; i < size; i++) {
341  kaldi::ReadIntegerVector(is, binary, &((*info)[i]));
342  }
343 }
344 
345 SymbolTable *CreateILabelInfoSymbolTable(const vector<vector<int32> > &info,
346  const SymbolTable &phones_symtab,
347  std::string separator,
348  std::string initial_disambig) { // e.g. separator = "/", initial-disambig="#-1"
349  KALDI_ASSERT(!info.empty() && info[0].empty());
350  SymbolTable *ans = new SymbolTable("ilabel-info-symtab");
351  int64 s = ans->AddSymbol(phones_symtab.Find(static_cast<int64>(0)));
352  assert(s == 0);
353  for (size_t i = 1; i < info.size(); i++) {
354  if (info[i].size() == 0) {
355  KALDI_ERR << "Invalid ilabel-info";
356  }
357  if (info[i].size() == 1 &&
358  info[i][0] <= 0) {
359  if (info[i][0] == 0) { // special symbol at start that we want to call #-1.
360  s = ans->AddSymbol(initial_disambig);
361  if (s != i) {
362  KALDI_ERR << "Disambig symbol " << initial_disambig
363  << " already in vocab";
364  }
365  } else {
366  std::string disambig_sym = phones_symtab.Find(-info[i][0]);
367  if (disambig_sym == "") {
368  KALDI_ERR << "Disambig symbol " << -info[i][0]
369  << " not in phone symbol-table";
370  }
371  s = ans->AddSymbol(disambig_sym);
372  if (s != i) {
373  KALDI_ERR << "Disambig symbol " << disambig_sym
374  << " already in vocab";
375  }
376  }
377  } else {
378  // is a phone-context-window.
379  std::string newsym;
380  for (size_t j = 0; j < info[i].size(); j++) {
381  std::string phonesym = phones_symtab.Find(info[i][j]);
382  if (phonesym == "") {
383  KALDI_ERR << "Symbol " << info[i][j]
384  << " not in phone symbol-table";
385  }
386  if (j != 0) newsym += separator;
387  newsym += phonesym;
388  }
389  int64 s = ans->AddSymbol(newsym);
390  if (s != static_cast<int64>(i)) {
391  KALDI_ERR << "Some problem with duplicate symbols";
392  }
393  }
394  }
395  return ans;
396 }
397 
398 
399 
400 
401 } // end namespace fst
fst::StdArc::StateId StateId
void WriteILabelInfo(std::ostream &os, bool binary, const vector< vector< int32 > > &info)
Utility function for writing ilabel-info vectors to disk.
Definition: context-fst.cc:325
VectorToStateMap state_map_
Definition: context-fst.h:314
StdArc::StateId StateId
Definition: context-fst.h:155
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
void CreateDisambigArc(StateId s, Label ilabel, Arc *arc)
Create disambiguation-symbol self-loop arc; where 'ilabel' must correspond to a disambiguation symbol...
Definition: context-fst.cc:184
fst::StdArc StdArc
void CreatePhoneOrEpsArc(StateId src, StateId dst, Label ilabel, const std::vector< int32 > &phone_seq, Arc *arc)
Creates an arc, this function is to be called only when 'ilabel' corresponds to a phone...
Definition: context-fst.cc:196
Label FindLabel(const std::vector< int32 > &label_info)
Finds the label index corresponding to this context-window of phones (likely of width context_width_)...
Definition: context-fst.cc:231
std::vector< std::vector< int32 > > ilabel_info_
Definition: context-fst.h:333
virtual Weight Final(StateId s)
Definition: context-fst.cc:109
kaldi::int32 int32
InverseContextFst(Label subsequential_symbol, const std::vector< int32 > &phones, const std::vector< int32 > &disambig_syms, int32 context_width, int32 central_position)
Constructor.
Definition: context-fst.cc:27
bool IsDisambigSymbol(Label lab)
Definition: context-fst.h:213
void GetInputSymbols(const Fst< Arc > &fst, bool include_eps, std::vector< I > *symbols)
GetInputSymbols gets the list of symbols on the input of fst (including epsilon, if include_eps == tr...
kaldi::ConstIntegerSet< Label > phone_syms_
Definition: context-fst.h:276
StdArc::Label Label
Definition: context-fst.h:157
void ComposeDeterministicOnDemandInverse(const Fst< Arc > &right, DeterministicOnDemandFst< Arc > *left, MutableFst< Arc > *fst_composed)
This function does '*fst_composed = Compose(Inverse(*fst2), fst1)' Note that the arguments are revers...
void ReadILabelInfo(std::istream &is, bool binary, vector< vector< int32 > > *info)
Utility function for reading ilabel-info vectors from disk.
Definition: context-fst.cc:335
VectorToLabelMap ilabel_map_
Definition: context-fst.h:325
virtual bool GetArc(StateId s, Label ilabel, Arc *arc)
Note: ilabel must not be epsilon.
Definition: context-fst.cc:129
void GetFullPhoneSequence(const std::vector< int32 > &seq, Label label, std::vector< int32 > *full_phone_sequence)
This utility function does something equivalent to the following 3 steps: *full_phone_sequence = seq;...
Definition: context-fst.cc:93
void ReadIntegerVector(std::istream &is, bool binary, std::vector< T > *v)
Function for reading STL vector of integer types.
Definition: io-funcs-inl.h:232
void ShiftSequenceLeft(Label label, std::vector< int32 > *phone_seq)
If phone_seq is nonempty then this function shifts it left by one and appends 'label' to it...
Definition: context-fst.cc:85
void SwapIlabelInfo(std::vector< std::vector< int32 > > *vec)
Definition: context-fst.h:200
#define KALDI_ERR
Definition: kaldi-error.h:147
StateId FindState(const std::vector< int32 > &seq)
Returns the state-id corresponding to this vector of phones; creates the state if necessary...
Definition: context-fst.cc:216
#define KALDI_PARANOID_ASSERT(cond)
Definition: kaldi-error.h:206
void AddSubsequentialLoop(StdArc::Label subseq_symbol, MutableFst< StdArc > *fst)
Modifies an FST so that it transduces the same paths, but the input side of the paths can all have the...
Definition: context-fst.cc:297
#define KALDI_WARN
Definition: kaldi-error.h:150
fst::StdArc::Label Label
fst::StdArc::Weight Weight
kaldi::ConstIntegerSet< Label > disambig_syms_
Definition: context-fst.h:285
void ComposeContext(const vector< int32 > &disambig_syms_in, int32 context_width, int32 central_position, VectorFst< StdArc > *ifst, VectorFst< StdArc > *ofst, vector< vector< int32 > > *ilabels_out, bool project_ifst)
Used in the command-line tool fstcomposecontext.
Definition: context-fst.cc:246
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
SymbolTable * CreateILabelInfoSymbolTable(const vector< vector< int32 > > &info, const SymbolTable &phones_symtab, std::string separator, std::string initial_disambig)
The following function is mainly of use for printing and debugging.
Definition: context-fst.cc:345
void WriteIntegerVector(std::ostream &os, bool binary, const std::vector< T > &v)
Function for writing STL vectors of integer types.
Definition: io-funcs-inl.h:198
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34
StdArc::Weight Weight
Definition: context-fst.h:156
bool IsPhoneSymbol(Label lab)
Definition: context-fst.h:215
std::vector< std::vector< int32 > > state_seqs_
Definition: context-fst.h:318