grammar-fst.h
Go to the documentation of this file.
1 // decoder/grammar-fst.h
2 
3 // Copyright 2018 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #ifndef KALDI_DECODER_GRAMMAR_FST_H_
21 #define KALDI_DECODER_GRAMMAR_FST_H_
22 
35 #include "fst/fstlib.h"
37 
38 namespace fst {
39 
40 
41 // GrammarFstArc is an FST Arc type which differs from the normal StdArc type by
42 // having the state-id be 64 bits, enough to store two indexes: the higher 32
43 // bits for the FST-instance index, and the lower 32 bits for the state within
44 // that FST-instance.
45 // Obviously this leads to very high-numbered state indexes, which might be
46 // a problem in some circumstances, but the decoder code doesn't store arrays
47 // indexed by state, it uses hashes, so this isn't a problem.
48 struct GrammarFstArc {
49  typedef fst::TropicalWeight Weight;
50  typedef int Label; // OpenFst's StdArc uses int; this is for compatibility.
51  typedef int64 StateId;
52 
53  Label ilabel;
54  Label olabel;
55  Weight weight;
56  StateId nextstate;
57 
59 
60  GrammarFstArc(Label ilabel, Label olabel, Weight weight, StateId nextstate)
61  : ilabel(ilabel),
62  olabel(olabel),
63  weight(std::move(weight)),
64  nextstate(nextstate) {}
65 };
66 
// Sentinel final-prob value.  A state of a component FST whose final-prob
// equals this value is not actually final: the value is a signal that the
// state needs expansion.  GrammarFst::Final() reports such states as
// Weight::Zero(), and NumInputEpsilons() and the ArcIterator specialization
// test for this value to decide whether to use the expanded form of a state.
#define KALDI_GRAMMAR_FST_SPECIAL_WEIGHT 4096.0

// Forward declaration, needed by the ArcIterator specialization below.
class GrammarFst;

// Declare that we'll be overriding class ArcIterator for class GrammarFst.
// This wouldn't work if we were fully using the OpenFst framework,
// e.g. if we had GrammarFst inherit from class Fst.
template<> class ArcIterator<GrammarFst>;
75 
76 
// GrammarFst is an FST that is 'stitched together' from a top-level FST and a
// list of (nonterminal, FST) pairs, which can invoke each other.  Its 64-bit
// state-ids hold an FST-instance index in the high 32 bits and the state
// within that instance in the low 32 bits.  It does not inherit from
// fst::Fst; arcs are read through the ArcIterator<GrammarFst> specialization,
// which may create expanded states on demand.
class GrammarFst {
 public:
  typedef GrammarFstArc Arc;
  typedef TropicalWeight Weight;

  // StateId is actually int64.  The high-order 32 bits are interpreted as an
  // instance_id, i.e. an index into the instances_ vector; the low-order 32
  // bits are the state index in the FST instance.
  // NOTE(review): the typedef itself is missing from this listing; per the
  // cross-reference index it is 'typedef Arc::StateId StateId;'.

  // The StateId of the individual FST instances (int, currently).
  // NOTE(review): the typedef itself is missing from this listing; per the
  // cross-reference index it is 'typedef StdArc::StateId BaseStateId;'.

  typedef Arc::Label Label;


  // Constructs from the nonterminal-phones offset, the top-level FST and the
  // list of (nonterminal symbol, FST) pairs.  The FSTs are held through
  // shared_ptr, so this object shares ownership of them.
  GrammarFst(
      int32 nonterm_phones_offset,
      std::shared_ptr<const ConstFst<StdArc> > top_fst,
      const std::vector<std::pair<int32, std::shared_ptr<const ConstFst<StdArc> > > > &ifsts);

  // Defaulted (memberwise) copy.
  // NOTE(review): this also copies the raw ExpandedState* pointers stored in
  // instances_[*].expanded_states.  If ~GrammarFst()/Destroy() deletes
  // them, then copying an object that has already expanded states and
  // destroying both copies would delete them twice.  Confirm, or only copy
  // objects that have not been used yet.
  GrammarFst(const GrammarFst &other) = default;

  // NOTE(review): a default constructor 'GrammarFst()' ("should only be used
  // prior to calling Read()", per the cross-reference index) is missing from
  // this listing.

  // This Write function allows you to dump a GrammarFst to disk as a single
  // object.  It only supports binary mode, but the option is allowed for
  // compatibility with other Kaldi read/write functions (it will crash if
  // binary == false).
  void Write(std::ostream &os, bool binary) const;

  // Reads the format that Write() outputs.  Will crash if binary == false.
  // (The input-stream parameter is named 'os' here, despite being an istream.)
  void Read(std::istream &os, bool binary);

  // Returns the start state: the start state of top_fst_ in instance 0.
  StateId Start() const {
    // the top 32 bits of the 64-bit state-id will be zero, because the
    // top FST instance has instance-id = 0.
    return static_cast<StateId>(top_fst_->Start());
  }

  // Returns the final-prob of state 's'.  Only states of the top-level
  // instance (instance-id 0) can be final.  A final-prob equal to
  // KALDI_GRAMMAR_FST_SPECIAL_WEIGHT marks a state that needs expansion, not
  // a real final-prob, so it is reported as Weight::Zero().
  Weight Final(StateId s) const {
    // If the fst-id (top 32 bits of s) is nonzero, this state is not final,
    // because we need to return to the top-level FST before we can be final.
    if (s != static_cast<StateId>(static_cast<int32>(s))) {
      return Weight::Zero();
    } else {
      BaseStateId base_state = static_cast<BaseStateId>(s);
      Weight ans = top_fst_->Final(base_state);
      if (ans.Value() == KALDI_GRAMMAR_FST_SPECIAL_WEIGHT) {
        return Weight::Zero();
      } else {
        return ans;
      }
    }
  }

  // This is called in LatticeFasterDecoder.  As an implementation shortcut, if
  // the state is an expanded state, we return 1, meaning 'yes, there are input
  // epsilons'; the calling code doesn't actually care about the exact number.
  inline size_t NumInputEpsilons(StateId s) const {
    // Compare with the constructor of ArcIterator.
    int32 instance_id = s >> 32;
    BaseStateId base_state = static_cast<int32>(s);
    const GrammarFst::FstInstance &instance = instances_[instance_id];
    const ConstFst<StdArc> *base_fst = instance.fst;
    if (base_fst->Final(base_state).Value() != KALDI_GRAMMAR_FST_SPECIAL_WEIGHT) {
      return base_fst->NumInputEpsilons(base_state);
    } else {
      return 1;
    }
  }

  // Returns the FST type name, "grammar".
  inline std::string Type() const { return "grammar"; }

  ~GrammarFst();
 private:

  // NOTE(review): 'ExpandedState' is used in member declarations below before
  // its definition, so a forward declaration (e.g. 'struct ExpandedState;')
  // is required here; it is missing from this listing.

  friend class ArcIterator<GrammarFst>;

  // sets up nonterminal_map_.
  void InitNonterminalMap();

  // sets up entry_arcs_[i].  We do this only on demand, as each one is
  // accessed, so that if there are a lot of nonterminals, this object doesn't
  // do too much work when it is initialized.  Each call to this function only
  // takes time O(number of left-context phones), which is quite small, but we'd
  // like to avoid that if possible.
  //
  // This function returns true if it successfully initialized
  // entry_arcs_[i], and false if it left it empty.
  // NOTE(review): the condition under which it is left empty is not stated
  // here; check the definition in grammar-fst.cc.
  bool InitEntryArcs(int32 i);

  // sets up instances_ with the top-level instance.
  void InitInstances();

  // Does the initialization tasks that remain after nonterm_phones_offset_,
  // top_fst_ and ifsts_ have been set up.
  void Init();

  // clears everything.
  void Destroy();

  /*
    This utility function sets up a map from "left-context phone", meaning
    either a phone index or the index of the symbol #nonterm_bos, to
    an arc-index leaving a particular state in an FST (i.e. an index
    that we could use to Seek() to the matching arc).

      @param [in] fst  The FST that is being entered (or reentered)
      @param [in] entry_state  The state in 'fst' which is being entered
                     (or reentered); will be fst.Start() if it's being
                     entered.  It must have arcs with ilabels decodable as
                     (nonterminal_symbol, left_context_phone).  Will either be the
                     start state (if 'nonterminal_symbol' corresponds to
                     #nonterm_begin), or an internal state (if 'nonterminal_symbol'
                     corresponds to #nonterm_reenter).  The arc-indexes of those
                     arcs will be the values we set in 'phone_to_arc'
      @param [in] nonterminal_symbol  The index in phones.txt of the
                     nonterminal symbol we expect to be encoded in the ilabels
                     of the arcs leaving 'entry_state'.  Will either correspond
                     to #nonterm_begin or #nonterm_reenter.
      @param [out] phone_to_arc  We output the map from left_context_phone
                     to the arc-index (i.e. the index we'd have to Seek() to
                     in an arc-iterator set up for the state 'entry_state').
   */
  void InitEntryOrReentryArcs(
      const ConstFst<StdArc> &fst,
      int32 entry_state,
      int32 nonterminal_symbol,
      std::unordered_map<int32, int32> *phone_to_arc);


  // NOTE(review): the declaration line of the member function whose body
  // follows is missing from this listing.  Per the cross-reference index it
  // is 'int32 GetPhoneSymbolFor(enum NonterminalValues n)'.  It returns the
  // phones.txt index of special symbol 'n', i.e. nonterm_phones_offset_ + n.
    return nonterm_phones_offset_ + static_cast<int32>(n);
  }

  // Splits an ilabel 'label' that encodes a (nonterminal_symbol,
  // left_context_phone) pair -- the encoding described for
  // InitEntryOrReentryArcs() above -- into its two parts.
  void DecodeSymbol(Label label,
                    int32 *nonterminal_symbol,
                    int32 *left_context_phone);


  // This function creates and returns an ExpandedState corresponding to a
  // particular state-id in the FstInstance for this instance_id.  It is called
  // when we have determined that an ExpandedState needs to be created and that
  // it is not currently present.  It creates and returns it; the calling code
  // needs to add it to the expanded_states map for its FST instance.
  ExpandedState *ExpandState(int32 instance_id, BaseStateId state_id);

  // Called from ExpandState() when the nonterminal type on the arcs is
  // #nonterm_end, this implements ExpandState() for that case.
  ExpandedState *ExpandStateEnd(int32 instance_id, BaseStateId state_id);

  // Called from ExpandState() when the nonterminal type on the arcs is a
  // user-defined nonterminal, this implements ExpandState() for that case.
  ExpandedState *ExpandStateUserDefined(int32 instance_id, BaseStateId state_id);

  // Called from ExpandStateUserDefined(), this function attempts to look up the
  // pair (nonterminal, state) in the map
  // instances_[instance_id].child_instances.  If it exists (because this
  // return-state has been expanded before), it returns the value it found;
  // otherwise it creates the child-instance and returns its newly created
  // instance-id.
  inline int32 GetChildInstanceId(int32 instance_id, int32 nonterminal,
                                  int32 state);

  // NOTE(review): this function's documentation is not in this listing.  From
  // the signature, it writes into '*arc' an arc built from 'leaving_arc',
  // 'arriving_arc' and 'cost_correction'; confirm the exact semantics in
  // grammar-fst.cc.
  static inline void CombineArcs(const StdArc &leaving_arc,
                                 const StdArc &arriving_arc,
                                 float cost_correction,
                                 StdArc *arc);

  // Returns the ExpandedState for state 'state_id' of FST instance
  // 'instance_id'.  If it is not yet cached in that instance's
  // 'expanded_states' map, this creates it with ExpandState() and caches it.
  // Called from the ArcIterator constructor for states whose final-prob
  // equals KALDI_GRAMMAR_FST_SPECIAL_WEIGHT.
  inline ExpandedState *GetExpandedState(int32 instance_id,
                                         BaseStateId state_id) {
    std::unordered_map<BaseStateId, ExpandedState*> &expanded_states =
        instances_[instance_id].expanded_states;

    std::unordered_map<BaseStateId, ExpandedState*>::iterator iter =
        expanded_states.find(state_id);
    if (iter != expanded_states.end()) {
      return iter->second;
    } else {
      ExpandedState *ans = ExpandState(instance_id, state_id);
      // Don't use the reference 'expanded_states'; it could have been
      // invalidated.  ExpandState() may create child instances (see
      // GetChildInstanceId()), and growing instances_ can move its elements.
      instances_[instance_id].expanded_states[state_id] = ans;
      return ans;
    }
  }

  // Represents an expanded state in an FstInstance.  Expanded states are
  // created on demand for states whose final-prob equals
  // KALDI_GRAMMAR_FST_SPECIAL_WEIGHT.
  struct ExpandedState {
    // The final-prob for expanded states is always zero; to avoid
    // corner cases, we ensure this via adding epsilon arcs where
    // needed.

    // fst-instance index of destination state (we will have ensured previously
    // that this is the same for all outgoing arcs).
    // NOTE(review): the member this comment describes ('dest_fst_instance',
    // read by ArcIterator<GrammarFst>) is missing from this listing.

    // List of arcs out of this state, where the 'nextstate' element will be the
    // lower-order 32 bits of the destination state and the higher order bits
    // will be given by 'dest_fst_instance'.  We do it this way, instead of
    // constructing a vector<Arc>, in order to simplify the ArcIterator code and
    // avoid unnecessary branches in loops over arcs.
    // We guarantee that this 'arcs' array will always be nonempty; this
    // is to avoid certain hassles on Windows with automated bounds-checking.
    std::vector<StdArc> arcs;
  };


  // An FstInstance is a copy of an FST.  The instance numbered zero is for
  // top_fst_, and (to state it approximately) whenever any FST instance invokes
  // another FST a new instance will be generated on demand.
  struct FstInstance {
    // ifst_index is the index into the ifsts_ vector that corresponds to this
    // FST instance, or -1 if this is the top-level instance.
    // NOTE(review): the 'ifst_index' member declaration is missing from this
    // listing.

    // Pointer to the FST corresponding to this instance: it will equal top_fst_
    // if ifst_index == -1, or ifsts_[ifst_index].second otherwise.
    const ConstFst<StdArc> *fst;

    // 'expanded_states', which will be populated on demand as states in this
    // FST instance are accessed, will only contain entries for states in this
    // FST whose final-prob equals KALDI_GRAMMAR_FST_SPECIAL_WEIGHT.  (That
    // final-prob value is used as a kind of signal to this code that the state
    // needs expansion.)
    std::unordered_map<BaseStateId, ExpandedState*> expanded_states;

    // 'child_instances', which is populated on demand as states in this FST
    // instance are accessed, is logically a map from pair (nonterminal_index,
    // return_state) to instance_id.  When we encounter an arc in our FST with a
    // user-defined nonterminal indexed 'nonterminal_index' on its ilabel, and
    // with 'return_state' as its nextstate, we look up that pair
    // (nonterminal_index, return_state) in this map to see whether there already
    // exists an FST instance for that.  If it exists then the transition goes to
    // that FST instance; if not, then we create a new one.  The 'return_state'
    // that's part of the key in this map would be the same as the 'parent_state'
    // in that child FST instance, and of course the 'parent_instance' in
    // that child FST instance would be the instance_id of this instance.
    //
    // In most cases each return_state would only have a single
    // nonterminal_index, making the 'nonterminal_index' in the key *usually*
    // redundant, but in principle it could happen that two user-defined
    // nonterminals might share the same return-state.
    std::unordered_map<int64, int32> child_instances;

    // The instance-id of the FST we return to when we are done with this one
    // (or -1 if this is the top-level FstInstance so there is nowhere to
    // return).
    // NOTE(review): the 'parent_instance' member declaration is missing from
    // this listing.

    // The state in the FST of 'parent_instance' at which we expanded this FST
    // instance, and to which we return (actually we return to the next-states
    // of arcs out of 'parent_state').
    // NOTE(review): the 'parent_state' member declaration is missing from this
    // listing.

    // 'parent_reentry_arcs' is a map from left-context-phone (i.e. either a
    // phone index or #nonterm_bos), to an arc-index, which we could use to
    // Seek() in an arc-iterator for state parent_state in the FST-instance
    // 'parent_instance'.  It's set up when we create this FST instance.  (The
    // arcs used to enter this instance are not located here, they can be
    // located in entry_arcs_[instance_id]).  We make use of reentry_arcs when
    // we expand states in this FST that have #nonterm_end on their arcs,
    // leading to final-states, which signal a return to the parent
    // FST-instance.
    std::unordered_map<int32, int32> parent_reentry_arcs;
  };

  // The integer id of the symbol #nonterm_bos in phones.txt.
  // NOTE(review): the member declaration this comment describes is missing
  // from this listing; per the cross-reference index it is
  // 'int32 nonterm_phones_offset_' (read by GetPhoneSymbolFor() above).

  // The top-level FST passed in by the user; contains the start state and
  // final-states, and may invoke FSTs in 'ifsts_' (which can also invoke
  // each other recursively).
  std::shared_ptr<const ConstFst<StdArc> > top_fst_;

  // A list of pairs (nonterm, fst), where 'nonterm' is a user-defined
  // nonterminal symbol as numbered in phones.txt (e.g. #nonterm:foo), and
  // 'fst' is the corresponding FST.
  std::vector<std::pair<int32, std::shared_ptr<const ConstFst<StdArc> > > > ifsts_;

  // Maps from the user-defined nonterminals like #nonterm:foo as numbered
  // in phones.txt, to the corresponding index into 'ifsts_', i.e. the ifst_index.
  std::unordered_map<int32, int32> nonterminal_map_;

  // entry_arcs_ will have the same dimension as ifsts_.  Each entry_arcs_[i]
  // is a map from left-context phone (i.e. either a phone-index or
  // #nonterm_bos) to the corresponding arc-index leaving the start-state in
  // the FST 'ifsts_[i].second'.
  // We populate this only on demand as each one is needed (except for the
  // first one, which we populate immediately as a kind of sanity check).
  // Doing it on-demand prevents this object's initialization from being
  // nontrivial in the case where there are a lot of nonterminals.
  std::vector<std::unordered_map<int32, int32> > entry_arcs_;

  // The FST instances.  Initially it is a vector with just one element
  // representing top_fst_, and it will be populated with more elements on
  // demand.  An instance_id refers to an index into this vector.
  std::vector<FstInstance> instances_;
};
487 
488 
495 template <>
496 class ArcIterator<GrammarFst> {
497  public:
498  using Arc = typename GrammarFst::Arc;
499  using BaseArc = StdArc;
500  using StateId = typename Arc::StateId; // int64
501  using BaseStateId = typename StdArc::StateId; // int
503 
504  // Caution: uses const_cast to evade const rules on GrammarFst. This is for
505  // compatibility with how things work in OpenFst.
506  inline ArcIterator(const GrammarFst &fst_in, StateId s) {
507  GrammarFst &fst = const_cast<GrammarFst&>(fst_in);
508  // 'instance_id' is the high order bits of the state.
509  int32 instance_id = s >> 32;
510  // 'base_state' is low order bits of the state. It's important to
511  // explicitly say int32 below, not BaseStateId == int, which might on some
512  // compilers be a 64-bit type.
513  BaseStateId base_state = static_cast<int32>(s);
514  const GrammarFst::FstInstance &instance = fst.instances_[instance_id];
515  const ConstFst<StdArc> *base_fst = instance.fst;
516  if (base_fst->Final(base_state).Value() != KALDI_GRAMMAR_FST_SPECIAL_WEIGHT) {
517  // A normal state
518  dest_instance_ = instance_id;
519  base_fst->InitArcIterator(s, &data_);
520  i_ = 0;
521  } else {
522  // A special state
523  ExpandedState *expanded_state = fst.GetExpandedState(instance_id,
524  base_state);
525  dest_instance_ = expanded_state->dest_fst_instance;
526  // it's ok to leave the other members of data_ uninitialized, as they will
527  // never be interrogated.
528  data_.arcs = &(expanded_state->arcs[0]);
529  data_.narcs = expanded_state->arcs.size();
530  i_ = 0;
531  }
532  // Ideally we want to call CopyArcToTemp() now, but we rely on the fact that
533  // the calling code needs to call Done() before accessing Value(); we call
534  // CopyArcToTemp() from Done(). Of course this is slightly against the
535  // semantics of Done(), but it's more efficient to have Done() call
536  // CopyArcToTemp() than this function or Next(), as Done() already has to
537  // test that the arc-iterator has not reached the end.
538  }
539 
540  inline bool Done() {
541  if (i_ < data_.narcs) {
542  CopyArcToTemp();
543  return false;
544  } else {
545  return true;
546  }
547  }
548 
549  inline void Next() {
550  i_++;
551  // Note: logically, at this point we should do:
552  // if (i_ < data_.size)
553  // CopyArcToTemp();
554  // Instead we move this CopyArcToTemp() invocation into Done(), which we
555  // know will always be called after Next() and before Value(), because the
556  // user has no other way of knowing whether the iterator is still valid.
557  // This is for efficiency.
558  }
559 
560  inline const Arc &Value() const { return arc_; }
561 
562  private:
563 
564  inline void CopyArcToTemp() {
565  const StdArc &src = data_.arcs[i_];
566  arc_.ilabel = src.ilabel;
567  arc_.olabel = src.olabel;
568  arc_.weight = src.weight;
569  arc_.nextstate = (static_cast<int64>(dest_instance_) << 32) |
570  src.nextstate;
571  }
572 
573  // The members of 'data_' that we use are:
574  // const Arc *arcs;
575  // size_t narcs;
576  ArcIteratorData<StdArc> data_;
577 
578 
579  int32 dest_instance_; // The index of the FstInstance that we transition to from
580  // this state.
581  size_t i_; // i_ is the index into the 'arcs' pointer.
582 
583  Arc arc_; // 'Arc' is the current arc in the GrammarFst, that this iterator
584  // is pointing to. It will be a copy of data_.arcs[i], except with
585  // the 'nextstate' modified to encode dest_instance_ in the higher
586  // order bits. Making a copy is of course unnecessary for the most
587  // part, but Value() needs to return a reference; we rely on the
588  // compiler to optimize out any unnecessary moves of data.
589 };
590 
600 void CopyToVectorFst(GrammarFst *grammar_fst,
601  VectorFst<StdArc> *vector_fst);
602 
638 void PrepareForGrammarFst(int32 nonterm_phones_offset,
639  VectorFst<StdArc> *fst);
640 
641 
642 } // end namespace fst
643 
644 
645 #endif
fst::StdArc::StateId StateId
std::unordered_map< BaseStateId, ExpandedState * > expanded_states
Definition: grammar-fst.h:413
const ConstFst< StdArc > * fst
Definition: grammar-fst.h:406
std::vector< std::unordered_map< int32, int32 > > entry_arcs_
Definition: grammar-fst.h:480
ArcIterator(const GrammarFst &fst_in, StateId s)
Definition: grammar-fst.h:506
TropicalWeight Weight
Definition: grammar-fst.h:99
ExpandedState * GetExpandedState(int32 instance_id, BaseStateId state_id)
Called from the ArcIterator constructor when we encounter an FST state with nonzero final-prob...
Definition: grammar-fst.h:352
std::unordered_map< int32, int32 > nonterminal_map_
Definition: grammar-fst.h:470
std::vector< StdArc > arcs
Definition: grammar-fst.h:392
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
fst::StdArc StdArc
GrammarFstArc Arc
Definition: grammar-fst.h:98
kaldi::int32 int32
std::vector< FstInstance > instances_
Definition: grammar-fst.h:485
NonterminalValues
An anonymous enum to define some values for symbols used in our grammar-fst framework.
Arc::Label Label
Definition: grammar-fst.h:109
#define KALDI_GRAMMAR_FST_SPECIAL_WEIGHT
Definition: grammar-fst.h:67
void PrepareForGrammarFst(int32 nonterm_phones_offset, VectorFst< StdArc > *fst)
This function prepares 'ifst' for use in GrammarFst: it ensures that it has the expected properties...
Definition: grammar-fst.cc:982
uint64 data_
std::unordered_map< int32, int32 > parent_reentry_arcs
Definition: grammar-fst.h:452
GrammarFstArc(Label ilabel, Label olabel, Weight weight, StateId nextstate)
Definition: grammar-fst.h:60
StdArc::StateId BaseStateId
Definition: grammar-fst.h:107
GrammarFst()
This constructor should only be used prior to calling Read().
Definition: grammar-fst.h:154
Represents an expanded state in an FstInstance.
Definition: grammar-fst.h:376
std::vector< std::pair< int32, std::shared_ptr< const ConstFst< StdArc > > > > ifsts_
Definition: grammar-fst.h:466
struct rnnlm::@11::@12 n
int32 GetPhoneSymbolFor(enum NonterminalValues n)
Definition: grammar-fst.h:265
ArcIteratorData< StdArc > data_
Definition: grammar-fst.h:576
std::string Type() const
Definition: grammar-fst.h:203
GrammarFst is an FST that is 'stitched together' from multiple FSTs, that can recursively incorporate...
Definition: grammar-fst.h:96
typename Arc::StateId StateId
Definition: grammar-fst.h:500
const Arc & Value() const
Definition: grammar-fst.h:560
typename StdArc::StateId BaseStateId
Definition: grammar-fst.h:501
Weight Final(StateId s) const
Definition: grammar-fst.h:171
fst::TropicalWeight Weight
Definition: grammar-fst.h:49
std::shared_ptr< const ConstFst< StdArc > > top_fst_
Definition: grammar-fst.h:461
int32 nonterm_phones_offset_
Definition: grammar-fst.h:456
size_t NumInputEpsilons(StateId s) const
Definition: grammar-fst.h:190
Arc::StateId StateId
Definition: grammar-fst.h:104
std::unordered_map< int64, int32 > child_instances
Definition: grammar-fst.h:431
typename GrammarFst::Arc Arc
Definition: grammar-fst.h:498
void CopyToVectorFst(GrammarFst *grammar_fst, VectorFst< StdArc > *vector_fst)
This function copies a GrammarFst to a VectorFst (intended mostly for testing and comparison purposes...
Definition: grammar-fst.cc:988
StateId Start() const
Definition: grammar-fst.h:165