arpa-lm-compiler.cc
Go to the documentation of this file.
1 // lm/arpa-lm-compiler.cc
2 
3 // Copyright 2009-2011 Gilles Boulianne
4 // Copyright 2016 Smart Action LLC (kkm)
5 // Copyright 2017 Xiaohui Zhang
6 
7 // See ../../COPYING for clarification regarding multiple authors
8 //
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 //
13 // http://www.apache.org/licenses/LICENSE-2.0
14 //
15 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
17 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
18 // MERCHANTABLITY OR NON-INFRINGEMENT.
19 // See the Apache 2 License for the specific language governing permissions and
20 // limitations under the License.
21 
22 #include <algorithm>
23 #include <limits>
24 #include <sstream>
25 #include <utility>
26 
27 #include "base/kaldi-math.h"
28 #include "lm/arpa-lm-compiler.h"
29 #include "util/stl-utils.h"
30 #include "util/text-utils.h"
32 
33 namespace kaldi {
34 
36  public:
38  virtual void ConsumeNGram(const NGram& ngram, bool is_highest) = 0;
39 };
40 
41 namespace {
42 
43 typedef int32 StateId;
44 typedef int32 Symbol;
45 
46 // GeneralHistKey can represent state history in an arbitrarily large n
47 // n-gram model with symbol ids fitting int32.
48 class GeneralHistKey {
49  public:
50  // Construct key from being and end iterators.
51  template<class InputIt>
52  GeneralHistKey(InputIt begin, InputIt end) : vector_(begin, end) { }
53  // Construct empty history key.
54  GeneralHistKey() : vector_() { }
55  // Return tails of the key as a GeneralHistKey. The tails of an n-gram
56  // w[1..n] is the sequence w[2..n] (and the heads is w[1..n-1], but the
57  // key class does not need this operartion).
58  GeneralHistKey Tails() const {
59  return GeneralHistKey(vector_.begin() + 1, vector_.end());
60  }
61  // Keys are equal if represent same state.
62  friend bool operator==(const GeneralHistKey& a, const GeneralHistKey& b) {
63  return a.vector_ == b.vector_;
64  }
65  // Public typename HashType for hashing.
66  struct HashType : public std::unary_function<GeneralHistKey, size_t> {
67  size_t operator()(const GeneralHistKey& key) const {
68  return VectorHasher<Symbol>().operator()(key.vector_);
69  }
70  };
71 
72  private:
73  std::vector<Symbol> vector_;
74 };
75 
76 // OptimizedHistKey combines 3 21-bit symbol ID values into one 64-bit
77 // machine word. allowing significant memory reduction and some runtime
78 // benefit over GeneralHistKey. Since 3 symbols are enough to track history
79 // in a 4-gram model, this optimized key is used for smaller models with up
80 // to 4-gram and symbol values up to 2^21-1.
81 //
82 // See GeneralHistKey for interface requirements of a key class.
83 class OptimizedHistKey {
84  public:
85  enum {
86  kShift = 21, // 21 * 3 = 63 bits for data.
87  kMaxData = (1 << kShift) - 1
88  };
89  template<class InputIt>
90  OptimizedHistKey(InputIt begin, InputIt end) : data_(0) {
91  for (uint32 shift = 0; begin != end; ++begin, shift += kShift) {
92  data_ |= static_cast<uint64>(*begin) << shift;
93  }
94  }
95  OptimizedHistKey() : data_(0) { }
96  OptimizedHistKey Tails() const {
97  return OptimizedHistKey(data_ >> kShift);
98  }
99  friend bool operator==(const OptimizedHistKey& a, const OptimizedHistKey& b) {
100  return a.data_ == b.data_;
101  }
102  struct HashType : public std::unary_function<OptimizedHistKey, size_t> {
103  size_t operator()(const OptimizedHistKey& key) const { return key.data_; }
104  };
105 
106  private:
107  explicit OptimizedHistKey(uint64 data) : data_(data) { }
108  uint64 data_;
109 };
110 
111 } // namespace
112 
113 template <class HistKey>
115  public:
117  Symbol sub_eps);
118 
119  virtual void ConsumeNGram(const NGram &ngram, bool is_highest);
120 
121  private:
122  StateId AddStateWithBackoff(HistKey key, float backoff);
123  void CreateBackoff(HistKey key, StateId state, float weight);
124 
125  ArpaLmCompiler *parent_; // Not owned.
126  fst::StdVectorFst* fst_; // Not owned.
127  Symbol bos_symbol_;
128  Symbol eos_symbol_;
129  Symbol sub_eps_;
130 
132  typedef unordered_map<HistKey, StateId,
133  typename HistKey::HashType> HistoryMap;
134  HistoryMap history_;
135 };
136 
137 template <class HistKey>
139  ArpaLmCompiler* parent, fst::StdVectorFst* fst, Symbol sub_eps)
140  : parent_(parent), fst_(fst), bos_symbol_(parent->Options().bos_symbol),
141  eos_symbol_(parent->Options().eos_symbol), sub_eps_(sub_eps) {
142  // The algorithm maintains state per history. The 0-gram is a special state
143  // for empty history. All unigrams (including BOS) backoff into this state.
144  StateId zerogram = fst_->AddState();
145  history_[HistKey()] = zerogram;
146 
147  // Also, if </s> is not treated as epsilon, create a common end state for
148  // all transitions accepting the </s>, since they do not back off. This small
149  // optimization saves about 2% states in an average grammar.
150  if (sub_eps_ == 0) {
151  eos_state_ = fst_->AddState();
152  fst_->SetFinal(eos_state_, 0);
153  }
154 }
155 
156 template <class HistKey>
158  bool is_highest) {
159  // Generally, we do the following. Suppose we are adding an n-gram "A B
160  // C". Then find the node for "A B", add a new node for "A B C", and connect
161  // them with the arc accepting "C" with the specified weight. Also, add a
162  // backoff arc from the new "A B C" node to its backoff state "B C".
163  //
164  // Two notable exceptions are the highest order n-grams, and final n-grams.
165  //
166  // When adding a highest order n-gram (e. g., our "A B C" is in a 3-gram LM),
167  // the following optimization is performed. There is no point adding a node
168  // for "A B C" with a "C" arc from "A B", since there will be no other
169  // arcs ingoing to this node, and an epsilon backoff arc into the backoff
170  // model "B C", with the weight of \bar{1}. To save a node, create an arc
171  // accepting "C" directly from "A B" to "B C". This saves as many nodes
172  // as there are the highest order n-grams, which is typically about half
173  // the size of a large 3-gram model.
174  //
175  // Indeed, this does not apply to n-grams ending in EOS, since they do not
176  // back off. These are special, as they do not have a back-off state, and
177  // the node for "(..anything..) </s>" is always final. These are handled
178  // in one of the two possible ways, If symbols <s> and </s> are being
179  // replaced by epsilons, neither node nor arc is created, and the logprob
180  // of the n-gram is applied to its source node as final weight. If <s> and
181  // </s> are preserved, then a special final node for </s> is allocated and
182  // used as the destination of the "</s>" acceptor arc.
183  HistKey heads(ngram.words.begin(), ngram.words.end() - 1);
184  typename HistoryMap::iterator source_it = history_.find(heads);
185  if (source_it == history_.end()) {
186  // There was no "A B", therefore the probability of "A B C" is zero.
187  // Print a warning and discard current n-gram.
188  if (parent_->ShouldWarn())
190  << " skipped: no parent (n-1)-gram exists";
191  return;
192  }
193 
194  StateId source = source_it->second;
195  StateId dest;
196  Symbol sym = ngram.words.back();
197  float weight = -ngram.logprob;
198  if (sym == sub_eps_ || sym == 0) {
199  KALDI_ERR << " <eps> or disambiguation symbol " << sym << "found in the ARPA file. ";
200  }
201  if (sym == eos_symbol_) {
202  if (sub_eps_ == 0) {
203  // Keep </s> as a real symbol when not substituting.
204  dest = eos_state_;
205  } else {
206  // Treat </s> as if it was epsilon: mark source final, with the weight
207  // of the n-gram.
208  fst_->SetFinal(source, weight);
209  return;
210  }
211  } else {
212  // For the highest order n-gram, this may find an existing state, for
213  // non-highest, will create one (unless there are duplicate n-grams
214  // in the grammar, which cannot be reliably detected if highest order,
215  // so we better do not do that at all).
216  dest = AddStateWithBackoff(
217  HistKey(ngram.words.begin() + (is_highest ? 1 : 0),
218  ngram.words.end()),
219  -ngram.backoff);
220  }
221 
222  if (sym == bos_symbol_) {
223  weight = 0; // Accepting <s> is always free.
224  if (sub_eps_ == 0) {
225  // <s> is as a real symbol, only accepted in the start state.
226  source = fst_->AddState();
227  fst_->SetStart(source);
228  } else {
229  // The new state for <s> unigram history *is* the start state.
230  fst_->SetStart(dest);
231  return;
232  }
233  }
234 
235  // Add arc from source to dest, whichever way it was found.
236  fst_->AddArc(source, fst::StdArc(sym, sym, weight, dest));
237  return;
238 }
239 
240 // Find or create a new state for n-gram defined by key, and ensure it has a
241 // backoff transition. The key is either the current n-gram for all but
242 // highest orders, or the tails of the n-gram for the highest order. The
243 // latter arises from the chain-collapsing optimization described above.
244 template <class HistKey>
246  float backoff) {
247  typename HistoryMap::iterator dest_it = history_.find(key);
248  if (dest_it != history_.end()) {
249  // Found an existing state in the history map. Invariant: if the state in
250  // the map, then its backoff arc is in the FST. We are done.
251  return dest_it->second;
252  }
253  // Otherwise create a new state and its backoff arc, and register in the map.
254  StateId dest = fst_->AddState();
255  history_[key] = dest;
256  CreateBackoff(key.Tails(), dest, backoff);
257  return dest;
258 }
259 
260 // Create a backoff arc for a state. Key is a backoff destination that may or
261 // may not exist. When the destination is not found, naturally fall back to
262 // the lower order model, and all the way down until one is found (since the
263 // 0-gram model is always present, the search is guaranteed to terminate).
264 template <class HistKey>
266  HistKey key, StateId state, float weight) {
267  typename HistoryMap::iterator dest_it = history_.find(key);
268  while (dest_it == history_.end()) {
269  key = key.Tails();
270  dest_it = history_.find(key);
271  }
272 
273  // The arc should transduce either <eos> or #0 to <eps>, depending on the
274  // epsilon substitution mode. This is the only case when input and output
275  // label may differ.
276  fst_->AddArc(state, fst::StdArc(sub_eps_, 0, weight, dest_it->second));
277 }
278 
280  if (impl_ != NULL)
281  delete impl_;
282 }
283 
285  KALDI_ASSERT(impl_ == NULL);
286  // Use optimized implementation if the grammar is 4-gram or less, and the
287  // maximum attained symbol id will fit into the optimized range.
288  int64 max_symbol = 0;
289  if (Symbols() != NULL)
290  max_symbol = Symbols()->AvailableKey() - 1;
291  // If augmenting the symbol table, assume the worst case when all words in
292  // the model being read are novel.
293  if (Options().oov_handling == ArpaParseOptions::kAddToSymbols)
294  max_symbol += NgramCounts()[0];
295 
296  if (NgramCounts().size() <= 4 && max_symbol < OptimizedHistKey::kMaxData) {
298  } else {
299  impl_ = new ArpaLmCompilerImpl<GeneralHistKey>(this, &fst_, sub_eps_);
300  KALDI_LOG << "Reverting to slower state tracking because model is large: "
301  << NgramCounts().size() << "-gram with symbols up to "
302  << max_symbol;
303  }
304 }
305 
307  // <s> is invalid in tails, </s> in heads of an n-gram.
308  for (int i = 0; i < ngram.words.size(); ++i) {
309  if ((i > 0 && ngram.words[i] == Options().bos_symbol) ||
310  (i + 1 < ngram.words.size()
311  && ngram.words[i] == Options().eos_symbol)) {
312  if (ShouldWarn())
313  KALDI_WARN << LineReference()
314  << " skipped: n-gram has invalid BOS/EOS placement";
315  return;
316  }
317  }
318 
319  bool is_highest = ngram.words.size() == NgramCounts().size();
320  impl_->ConsumeNGram(ngram, is_highest);
321 }
322 
324  fst::StdArc::Label backoff_symbol = sub_eps_;
325  if (backoff_symbol == 0) {
326  // The method of removing redundant states implemented in this function
327  // leads to slow determinization of L o G when people use the older style of
328  // usage of arpa2fst where the --disambig-symbol option was not specified.
329  // The issue seems to be that it creates a non-deterministic FST, while G is
330  // supposed to be deterministic. By 'return'ing below, we just disable this
331  // method if people were using an older script. This method isn't really
332  // that consequential anyway, and people will move to the newer-style
333  // scripts (see current utils/format_lm.sh), so this isn't much of a
334  // problem.
335  return;
336  }
337 
338  fst::StdArc::StateId num_states = fst_.NumStates();
339 
340 
341  // replace the #0 symbols on the input of arcs out of redundant states (states
342  // that are not final and have only a backoff arc leaving them), with <eps>.
343  for (fst::StdArc::StateId state = 0; state < num_states; state++) {
344  if (fst_.NumArcs(state) == 1 && fst_.Final(state) == fst::TropicalWeight::Zero()) {
345  fst::MutableArcIterator<fst::StdVectorFst> iter(&fst_, state);
346  fst::StdArc arc = iter.Value();
347  if (arc.ilabel == backoff_symbol) {
348  arc.ilabel = 0;
349  iter.SetValue(arc);
350  }
351  }
352  }
353 
354  // we could call fst::RemoveEps, and it would have the same effect in normal
355  // cases, where backoff_symbol != 0 and there are no epsilons in unexpected
356  // places, but RemoveEpsLocal is a bit safer in case something weird is going
357  // on; it guarantees not to blow up the FST.
359  KALDI_LOG << "Reduced num-states from " << num_states << " to "
360  << fst_.NumStates();
361 }
362 
363 void ArpaLmCompiler::Check() const {
364  if (fst_.Start() == fst::kNoStateId) {
365  KALDI_ERR << "Arpa file did not contain the beginning-of-sentence symbol "
366  << Symbols()->Find(Options().bos_symbol) << ".";
367  }
368 }
369 
371  fst_.SetInputSymbols(Symbols());
372  fst_.SetOutputSymbols(Symbols());
373  RemoveRedundantStates();
374  Check();
375 }
376 
377 } // namespace kaldi
fst::StdArc::StateId StateId
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
A hashing function-object for vectors.
Definition: stl-utils.h:216
void RemoveEpsLocal(MutableFst< Arc > *fst)
RemoveEpsLocal removes some (but not necessarily all) epsilons in an FST, using an algorithm that is g...
virtual void ConsumeNGram(const NGram &ngram, bool is_highest)
Lattice::StateId StateId
virtual void ConsumeNGram(const NGram &ngram)
Pure override that must be implemented to process current n-gram.
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
virtual void ConsumeNGram(const NGram &ngram, bool is_highest)=0
fst::StdArc StdArc
float logprob
Log-prob of the n-gram.
kaldi::int32 int32
StateId AddStateWithBackoff(HistKey key, float backoff)
fst::StdVectorFst StdVectorFst
Add novel words to the symbol table.
uint64 data_
virtual void HeaderAvailable()
Override function called to signal that ARPA header with the expected number of n-grams has been read...
void CreateBackoff(HistKey key, StateId state, float weight)
unordered_map< HistKey, StateId, typename HistKey::HashType > HistoryMap
fst::StdVectorFst * fst_
float backoff
log-backoff weight of the n-gram.
float backoff
std::string LineReference() const
Inside ConsumeNGram(), returns a formatted reference to the line being compiled, to print out as part...
std::vector< int32 > words
Symbols in left to right order.
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
fst::StdArc::Label Label
ArpaLmCompilerImpl(ArpaLmCompiler *parent, fst::StdVectorFst *fst, Symbol sub_eps)
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
bool ShouldWarn()
Increments warning count, and returns true if a warning should be printed or false if the count has e...
bool operator==(const LatticeWeightTpl< FloatType > &wa, const LatticeWeightTpl< FloatType > &wb)
std::vector< Symbol > vector_
#define KALDI_LOG
Definition: kaldi-error.h:153
A parsed n-gram from ARPA LM file.
virtual void ReadComplete()
Override function called after the last n-gram has been consumed.