const-arpa-lm.h
Go to the documentation of this file.
1 // lm/const-arpa-lm.h
2 
3 // Copyright 2014 Guoguo Chen
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #ifndef KALDI_LM_CONST_ARPA_LM_H_
21 #define KALDI_LM_CONST_ARPA_LM_H_
22 
23 #include <string>
24 #include <vector>
25 
26 #include "base/kaldi-common.h"
28 #include "lm/arpa-file-parser.h"
29 #include "util/common-utils.h"
30 
31 namespace kaldi {
32 
199 // Forward declaration of Auxiliary struct ArpaLine.
200 struct ArpaLine;
201 
204  float f;
205 
207  Int32AndFloat(int32 input_i) : i(input_i) {}
208  Int32AndFloat(float input_f) : f(input_f) {}
209 };
210 
211 class ConstArpaLm {
212  public:
213 
214  // Default constructor, will be used if you are going to load the ConstArpaLm
215  // format language model from disk.
217  lm_states_ = NULL;
218  unigram_states_ = NULL;
219  overflow_buffer_ = NULL;
220  memory_assigned_ = false;
221  initialized_ = false;
222  }
223 
224  // Special constructor, will be used when you initialize ConstArpaLm from
225  // scratch through this constructor.
226  ConstArpaLm(const int32 bos_symbol, const int32 eos_symbol,
227  const int32 unk_symbol, const int32 ngram_order,
228  const int32 num_words, const int32 overflow_buffer_size,
229  const int64 lm_states_size, int32** unigram_states,
230  int32** overflow_buffer, int32* lm_states) :
231  bos_symbol_(bos_symbol), eos_symbol_(eos_symbol),
232  unk_symbol_(unk_symbol), ngram_order_(ngram_order),
233  num_words_(num_words), overflow_buffer_size_(overflow_buffer_size),
234  lm_states_size_(lm_states_size), unigram_states_(unigram_states),
235  overflow_buffer_(overflow_buffer), lm_states_(lm_states) {
236  KALDI_ASSERT(unigram_states_ != NULL);
237  KALDI_ASSERT(overflow_buffer_ != NULL);
238  KALDI_ASSERT(lm_states_ != NULL);
239  KALDI_ASSERT(ngram_order_ > 0);
240  KALDI_ASSERT(bos_symbol_ < num_words_ && bos_symbol_ > 0);
241  KALDI_ASSERT(eos_symbol_ < num_words_ && eos_symbol_ > 0);
242  KALDI_ASSERT(unk_symbol_ < num_words_ &&
243  (unk_symbol_ > 0 || unk_symbol_ == -1));
244  lm_states_end_ = lm_states_ + lm_states_size_ - 1;
245  memory_assigned_ = false;
246  initialized_ = true;
247  }
248 
250  if (memory_assigned_) {
251  delete[] lm_states_;
252  delete[] unigram_states_;
253  delete[] overflow_buffer_;
254  }
255  }
256 
257  // Reads the ConstArpaLm format language model. It calls ReadInternal() or
258  // ReadInternalOldFormat() to do the actual reading.
259  void Read(std::istream &is, bool binary);
260 
261  // Writes the language model in ConstArpaLm format.
262  void Write(std::ostream &os, bool binary) const;
263 
264  // Creates Arpa format language model from ConstArpaLm format, and writes it
265  // to output stream. This will be useful in testing.
266  void WriteArpa(std::ostream &os) const;
267 
268  // Wrapper of GetNgramLogprobRecurse. It first maps possible out-of-vocabulary
269  // words to <unk>, if <unk> is defined, and then calls GetNgramLogprobRecurse.
270  float GetNgramLogprob(const int32 word, const std::vector<int32>& hist) const;
271 
272  // Returns true if the history word sequence <hist> has successor, which means
273  // <hist> will be a state in the FST format language model.
274  bool HistoryStateExists(const std::vector<int32>& hist) const;
275 
276  int32 BosSymbol() const { return bos_symbol_; }
277  int32 EosSymbol() const { return eos_symbol_; }
278  int32 UnkSymbol() const { return unk_symbol_; }
279  int32 NgramOrder() const { return ngram_order_; }
280 
281  private:
282  // Function that loads data from stream to the class.
283  void ReadInternal(std::istream &is, bool binary);
284 
285  // Function that loads data from stream to the class. This is a deprecated one
286  // that handles the old on-disk format. We keep this for back-compatibility
287  // purpose. We have modified the Write() function so for all the new on-disk
288  // format, ReadInternal() will be called.
289  void ReadInternalOldFormat(std::istream &is, bool binary);
290 
291  // Loops up n-gram probability for given word sequence. Backoff is handled by
292  // recursively calling this function.
293  float GetNgramLogprobRecurse(const int32 word,
294  const std::vector<int32>& hist) const;
295 
296  // Given a word sequence, find the address of the corresponding LmState.
297  // Returns NULL if no corresponding LmState is found.
298  //
299  // If the word sequence exists in n-gram language model, but it is a leaf and
300  // is not an unigram, we still return NULL, since there is no LmState struct
301  // reserved for this sequence.
302  int32* GetLmState(const std::vector<int32>& seq) const;
303 
304  // Given a pointer to the parent, find the child_info that corresponds to
305  // given word. The parent has the following structure:
306  // struct LmState {
307  // float logprob;
308  // float backoff_logprob;
309  // int32 num_children;
310  // std::pair<int32, int32> [] children;
311  // }
312  // It returns false if the child is not found.
313  bool GetChildInfo(const int32 word, int32* parent, int32* child_info) const;
314 
315  // Decodes <child_info> to get log probability and child LmState. In the leaf
316  // case, only <logprob> will be returned, and <child_address> will be NULL.
317  void DecodeChildInfo(const int32 child_info, int32* parent,
318  int32** child_lm_state, float* logprob) const;
319 
320  void WriteArpaRecurse(int32* lm_state,
321  const std::vector<int32>& seq,
322  std::vector<ArpaLine> *output) const;
323 
324  // We assign memory in Read(). If it is called, we have to release memory in
325  // the destructor.
327 
328  // Makes sure that the language model has been loaded before using it.
330 
331  // Integer corresponds to <s>.
333 
334  // Integer corresponds to </s>.
336 
337  // Integer corresponds to unknown-word. -1 if no unknown-word symbol is
338  // provided.
340 
341  // N-gram order of the language model.
343 
344  // Index of largest word-id plus one. It defines the end of <unigram_states_>
345  // array.
347 
348  // Number of entries in the overflow buffer for pointers that couldn't be
349  // represented as a 30-bit relative index.
351 
352  // Size of the <lm_states_> array, which will be needed by I/O.
354 
355  // Points to the end of <lm_states_>. We use this information to check if
356  // there is any illegal visit to the un-reserved memory.
358 
359  // Loopup table for pointers of unigrams. The pointer could be NULL, for
360  // example for those words that are in words.txt, but not in the language
361  // model.
363 
364  // Technically a 32-bit number cannot represent a possibly 64-bit pointer. We
365  // therefore use "relative" address instead of "absolute" address, which will
366  // be a small number most of the time. This buffer is for the case where the
367  // relative address has more than 30-bits.
369 
370  // Memory chunk that contains the actual LmStates. One LmState has the
371  // following structure:
372  //
373  // struct LmState {
374  // float logprob;
375  // float backoff_logprob;
376  // int32 num_children;
377  // std::pair<int32, int32> [] children;
378  // }
379  //
380  // Note that the floating point representation has 4 bytes, int32 also has 4
381  // bytes, therefore one LmState will occupy the following number of bytes:
382  //
383  // x = 1 + 1 + 1 + 2 * children.size() = 3 + 2 * children.size()
385 };
386 
392  : public fst::DeterministicOnDemandFst<fst::StdArc> {
393  public:
397 
398  explicit ConstArpaLmDeterministicFst(const ConstArpaLm& lm);
399 
400  // We cannot use "const" because the pure virtual function in the interface is
401  // not const.
402  virtual StateId Start() { return start_state_; }
403 
404  // We cannot use "const" because the pure virtual function in the interface is
405  // not const.
406  virtual Weight Final(StateId s);
407 
408  virtual bool GetArc(StateId s, Label ilabel, fst::StdArc* oarc);
409 
410  private:
411  typedef unordered_map<std::vector<Label>,
413  StateId start_state_;
414  MapType wseq_to_state_;
415  std::vector<std::vector<Label> > state_to_wseq_;
416  const ConstArpaLm& lm_;
417 };
418 
419 // Reads in an Arpa format language model and converts it into ConstArpaLm
420 // format. We assume that the words in the input Arpa format language model have
421 // been converted into integers.
422 bool BuildConstArpaLm(const ArpaParseOptions& options,
423  const std::string& arpa_rxfilename,
424  const std::string& const_arpa_wxfilename);
425 
426 } // namespace kaldi
427 
428 #endif // KALDI_LM_CONST_ARPA_LM_H_
fst::StdArc::StateId StateId
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
int32 UnkSymbol() const
std::vector< std::vector< Label > > state_to_wseq_
A hashing function-object for vectors.
Definition: stl-utils.h:216
Lattice::StateId StateId
fst::StdArc StdArc
Options that control ArpaFileParser.
float logprob
kaldi::int32 int32
Int32AndFloat(int32 input_i)
class DeterministicOnDemandFst is an "FST-like" base-class.
int32 ** unigram_states_
ConstArpaLm(const int32 bos_symbol, const int32 eos_symbol, const int32 unk_symbol, const int32 ngram_order, const int32 num_words, const int32 overflow_buffer_size, const int64 lm_states_size, int32 **unigram_states, int32 **overflow_buffer, int32 *lm_states)
fst::StdArc::Label Label
int32 ** overflow_buffer_
int32 NgramOrder() const
fst::StdArc::Weight Weight
Int32AndFloat(float input_f)
int32 BosSymbol() const
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
int32 EosSymbol() const
bool BuildConstArpaLm(const ArpaParseOptions &options, const std::string &arpa_rxfilename, const std::string &const_arpa_wxfilename)
unordered_map< std::vector< Label >, StateId, VectorHasher< Label > > MapType
This class wraps a ConstArpaLm format language model with the interface defined in DeterministicOnDem...