const-arpa-lm.cc
// lm/const-arpa-lm.cc

// Copyright 2014  Guoguo Chen

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <limits>
#include <sstream>
#include <utility>

#include "base/kaldi-math.h"
#include "lm/arpa-file-parser.h"
#include "lm/const-arpa-lm.h"
#include "util/stl-utils.h"
#include "util/text-utils.h"


namespace kaldi {

// Auxiliary struct for converting a ConstArpaLm format language model to the
// ARPA format.
struct ArpaLine {
  std::vector<int32> words;  // Sequence of words to be printed.
  float logprob;             // Logprob corresponding to the word sequence.
  float backoff_logprob;     // Backoff logprob of the word sequence.
  // Comparison function for sorting.
  bool operator < (const ArpaLine &other) const {
    if (words.size() < other.words.size()) {
      return true;
    } else if (words.size() > other.words.size()) {
      return false;
    } else {
      return words < other.words;
    }
  }
};

// Auxiliary class for building a ConstArpaLm. We first use this class to
// figure out the relative addresses of the different LmStates, and then put
// everything into one block in memory.
class LmState {
 public:
  union ChildType {
    // If the child is not of the final order, we keep a pointer to its
    // LmState.
    LmState* state;

    // If the child is of the final order, we only keep its log probability.
    float prob;
  };

  struct ChildrenVectorLessThan {
    bool operator()(
        const std::pair<int32, union ChildType>& lhs,
        const std::pair<int32, union ChildType>& rhs) const {
      return lhs.first < rhs.first;
    }
  };

  LmState(const bool is_unigram, const bool is_child_final_order,
          const float logprob, const float backoff_logprob) :
      is_unigram_(is_unigram), is_child_final_order_(is_child_final_order),
      logprob_(logprob), backoff_logprob_(backoff_logprob) {}

  void SetMyAddress(const int64 address) {
    my_address_ = address;
  }

  void AddChild(const int32 word, LmState* child_state) {
    KALDI_ASSERT(!is_child_final_order_);
    ChildType child;
    child.state = child_state;
    children_.push_back(std::make_pair(word, child));
  }

  void AddChild(const int32 word, const float child_prob) {
    KALDI_ASSERT(is_child_final_order_);
    ChildType child;
    child.prob = child_prob;
    children_.push_back(std::make_pair(word, child));
  }

  int64 MyAddress() const {
    return my_address_;
  }

  bool IsUnigram() const {
    return is_unigram_;
  }

  bool IsChildFinalOrder() const {
    return is_child_final_order_;
  }

  float Logprob() const {
    return logprob_;
  }

  float BackoffLogprob() const {
    return backoff_logprob_;
  }

  int32 NumChildren() const {
    return children_.size();
  }

  std::pair<int32, union ChildType> GetChild(const int32 index) {
    KALDI_ASSERT(index < children_.size());
    KALDI_ASSERT(index >= 0);
    return children_[index];
  }

  void SortChildren() {
    std::sort(children_.begin(), children_.end(), ChildrenVectorLessThan());
  }

  // Checks if the current LmState is a leaf.
  bool IsLeaf() const {
    return (backoff_logprob_ == 0.0 && children_.empty());
  }

  // Computes the size of the memory that the current LmState would take in
  // the <lm_states> array. It's the number of 4-byte chunks.
  int32 MemSize() const {
    if (IsLeaf() && !is_unigram_) {
      // We don't create an entry in this case; the logprob will be stored in
      // the same int32 that we would normally store the pointer in.
      return 0;
    } else {
      // We store the following information:
      // logprob, backoff_logprob, children.size() and children data.
      return (3 + 2 * children_.size());
    }
  }
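
  // For illustration of the arithmetic above: a non-unigram leaf takes 0
  // slots (its logprob is packed into the parent's child entry instead),
  // while a state with two children takes 3 + 2 * 2 = 7 int32 slots.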

 private:
  // Unigram states will have LmStates even if they are leaves, therefore we
  // need to note whether this is a unigram or not.
  bool is_unigram_;

  // If the current LmState has an order of (final_order - 1), then its child
  // must be of the final order. We only keep the log probability for its
  // child.
  bool is_child_final_order_;

  // We compute the addresses of the LmStates as offsets into the <lm_states_>
  // array, and put those offsets here. Note that this is just an offset, not
  // an actual pointer.
  int64 my_address_;

  // Language model log probability of the current sequence. For example, if
  // this state is "A B", then it would be the logprob of "A -> B".
  float logprob_;

  // Language model backoff log probability of the current sequence, e.g.,
  // state "A B -> X" backing off to "B -> X".
  float backoff_logprob_;

  // List of children.
  std::vector<std::pair<int32, union ChildType> > children_;
};

// Class to build a ConstArpaLm from an ARPA format language model. It relies
// on the auxiliary class LmState above.
class ConstArpaLmBuilder : public ArpaFileParser {
 public:
  explicit ConstArpaLmBuilder(ArpaParseOptions options)
      : ArpaFileParser(options, NULL) {
    ngram_order_ = 0;
    num_words_ = 0;
    overflow_buffer_size_ = 0;
    lm_states_size_ = 0;
    max_address_offset_ = pow(2, 30) - 1;
    is_built_ = false;
    lm_states_ = NULL;
    unigram_states_ = NULL;
    overflow_buffer_ = NULL;
  }

  ~ConstArpaLmBuilder() {
    unordered_map<std::vector<int32>,
                  LmState*, VectorHasher<int32> >::iterator iter;
    for (iter = seq_to_state_.begin(); iter != seq_to_state_.end(); ++iter) {
      delete iter->second;
    }
    if (is_built_) {
      delete[] lm_states_;
      delete[] unigram_states_;
      delete[] overflow_buffer_;
    }
  }

  // Writes the ConstArpaLm.
  void Write(std::ostream &os, bool binary) const;

  void SetMaxAddressOffset(const int32 max_address_offset) {
    KALDI_WARN << "You are changing <max_address_offset_>; the default should "
               << "not be changed unless you are in testing mode.";
    max_address_offset_ = max_address_offset;
  }

 protected:
  // ArpaFileParser overrides.
  virtual void HeaderAvailable();
  virtual void ConsumeNGram(const NGram& ngram);
  virtual void ReadComplete();

 private:
  struct WordsAndLmStatePairLessThan {
    bool operator()(
        const std::pair<std::vector<int32>*, LmState*>& lhs,
        const std::pair<std::vector<int32>*, LmState*>& rhs) const {
      return *(lhs.first) < *(rhs.first);
    }
  };

 private:
  // Indicates whether the ConstArpaLm has been built or not.
  bool is_built_;

  // Maximum relative address for a child. We put it here just for testing.
  // The default value is 30 bits and should not be changed except for testing.
  int32 max_address_offset_;

  // N-gram order of the language model. This can be figured out from the
  // "\data\" section of the ARPA format language model.
  int32 ngram_order_;

  // Index of the largest word-id plus one. It defines the end of the
  // <unigram_states_> array.
  int32 num_words_;

  // Number of entries in the overflow buffer for pointers that couldn't be
  // represented as a 30-bit relative index.
  int32 overflow_buffer_size_;

  // Size of the <lm_states_> array, which will be needed by I/O.
  int64 lm_states_size_;

  // Memory block for storing LmStates.
  int32* lm_states_;

  // Memory block for storing pointers of unigram LmStates.
  int32** unigram_states_;

  // Memory block for storing pointers of the LmStates that have large
  // relative addresses to their parents.
  int32** overflow_buffer_;

  // Hash table from word sequences to LmStates.
  unordered_map<std::vector<int32>,
                LmState*, VectorHasher<int32> > seq_to_state_;
};

void ConstArpaLmBuilder::HeaderAvailable() {
  ngram_order_ = NgramCounts().size();
}

void ConstArpaLmBuilder::ConsumeNGram(const NGram& ngram) {
  int32 cur_order = ngram.words.size();
  // If <ngram_order_> is larger than 1, then we do not create an LmState for
  // the final-order entries. We only keep the log probability for them.
  LmState *lm_state = NULL;
  if (cur_order != ngram_order_ || ngram_order_ == 1) {
    lm_state = new LmState(cur_order == 1,
                           cur_order == ngram_order_ - 1,
                           ngram.logprob, ngram.backoff);

    if (seq_to_state_.find(ngram.words) != seq_to_state_.end()) {
      std::ostringstream os;
      os << "[ ";
      for (size_t i = 0; i < ngram.words.size(); i++) {
        os << ngram.words[i] << " ";
      }
      os << "]";

      KALDI_ERR << "N-gram " << os.str() << " appears twice in the ARPA file";
    }
    seq_to_state_[ngram.words] = lm_state;
  }

  // If the n-gram order is larger than 1, we have to add the n-gram as a
  // child of an existing LmState. We make the following two assumptions:
  // 1. N-grams are processed from small orders to larger ones, i.e., from
  //    1, 2, ... to the highest order.
  // 2. If an n-gram exists in the ARPA format language model, then the
  //    "history" n-gram also exists. For example, if "A B C" is a valid
  //    n-gram, then "A B" is also a valid n-gram.
  int32 last_word = ngram.words[cur_order - 1];
  if (cur_order > 1) {
    std::vector<int32> hist(ngram.words.begin(), ngram.words.end() - 1);
    unordered_map<std::vector<int32>,
                  LmState*, VectorHasher<int32> >::iterator hist_iter;
    hist_iter = seq_to_state_.find(hist);
    if (hist_iter == seq_to_state_.end()) {
      std::ostringstream ss;
      for (int i = 0; i < cur_order; ++i)
        ss << (i == 0 ? '[' : ' ') << ngram.words[i];
      KALDI_ERR << "In line " << LineNumber() << ": "
                << cur_order << "-gram " << ss.str() << "] does not have "
                << "a parent model " << cur_order - 1 << "-gram.";
    }
    if (cur_order != ngram_order_ || ngram_order_ == 1) {
      KALDI_ASSERT(lm_state != NULL);
      KALDI_ASSERT(!hist_iter->second->IsChildFinalOrder());
      hist_iter->second->AddChild(last_word, lm_state);
    } else {
      KALDI_ASSERT(lm_state == NULL);
      KALDI_ASSERT(hist_iter->second->IsChildFinalOrder());
      hist_iter->second->AddChild(last_word, ngram.logprob);
    }
  } else {
    // Figures out <max_word_id>.
    num_words_ = std::max(num_words_, last_word + 1);
  }
}
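
// For illustration: in a trigram LM, when ConsumeNGram() sees the trigram
// "A B C" (cur_order == ngram_order_), no LmState is created for it; instead
// C and its log probability are stored directly on the LmState of the
// history "A B", which assumption 2 above guarantees to exist:
//
//   hist_iter->second->AddChild(last_word, ngram.logprob);
//
// For the bigram "A B" in the same LM, an LmState is created and attached as
// a pointer-valued child of the unigram state "A".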

// ConstArpaLm can be built in the following steps, assuming we have already
// created the LmStates in <seq_to_state_>:
// 1. Sort LmStates lexicographically.
//    This enables us to compute relative addresses. When we say
//    lexicographic, we treat the word-ids as letters. After sorting, the
//    LmStates are in the following order:
//    ...
//    A B
//    A B A
//    A B B
//    A B C
//    ...
//    where each line represents an LmState.
// 2. Update <my_address> in each LmState, which is relative to the first
//    element in <sorted_vec>.
// 3. Put the following structure into the memory block:
//    struct LmState {
//      float logprob;
//      float backoff_logprob;
//      int32 num_children;
//      std::pair<int32, int32> [] children;
//    }
//
//    At the same time, we will also create two special buffers:
//    <unigram_states_>
//    <overflow_buffer_>
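//
// For illustration, a state with two children (word-ids 4 and 7) occupies
// seven int32 slots:
//
//   [ logprob | backoff_logprob | 2 | 4 | info(4) | 7 | info(7) ]
//
// where the float values are stored bit-for-bit in int32 slots via the
// Int32AndFloat union, and each info field is the packed child
// representation built in ReadComplete() below.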
void ConstArpaLmBuilder::ReadComplete() {
  // STEP 1: sorting LmStates lexicographically.
  // Vector for holding the sorted LmStates.
  std::vector<std::pair<std::vector<int32>*, LmState*> > sorted_vec;
  unordered_map<std::vector<int32>,
                LmState*, VectorHasher<int32> >::iterator iter;
  for (iter = seq_to_state_.begin(); iter != seq_to_state_.end(); ++iter) {
    if (iter->second->MemSize() > 0) {
      sorted_vec.push_back(
          std::make_pair(const_cast<std::vector<int32>*>(&(iter->first)),
                         iter->second));
    }
  }

  std::sort(sorted_vec.begin(), sorted_vec.end(),
            WordsAndLmStatePairLessThan());

  // STEP 2: updating <my_address> in each LmState.
  for (int32 i = 0; i < sorted_vec.size(); ++i) {
    lm_states_size_ += sorted_vec[i].second->MemSize();
    if (i == 0) {
      sorted_vec[i].second->SetMyAddress(0);
    } else {
      sorted_vec[i].second->SetMyAddress(sorted_vec[i - 1].second->MyAddress()
          + sorted_vec[i - 1].second->MemSize());
    }
  }

  // STEP 3: creating a memory block to store the LmStates.
  // Reserves a memory block for the LmStates.
  int64 lm_states_index = 0;
  try {
    lm_states_ = new int32[lm_states_size_];
  } catch(const std::exception &e) {
    KALDI_ERR << e.what();
  }

  // Puts data into the memory block.
  unigram_states_ = new int32*[num_words_];
  std::vector<int32*> overflow_buffer_vec;
  for (int32 i = 0; i < num_words_; ++i) {
    unigram_states_[i] = NULL;
  }
  for (int32 i = 0; i < sorted_vec.size(); ++i) {
    // Current address.
    int32* parent_address = lm_states_ + lm_states_index;

    // Adds logprob.
    Int32AndFloat logprob_f(sorted_vec[i].second->Logprob());
    lm_states_[lm_states_index++] = logprob_f.i;

    // Adds backoff_logprob.
    Int32AndFloat backoff_logprob_f(sorted_vec[i].second->BackoffLogprob());
    lm_states_[lm_states_index++] = backoff_logprob_f.i;

    // Adds num_children.
    lm_states_[lm_states_index++] = sorted_vec[i].second->NumChildren();

    // Adds the children; there are 3 cases:
    // 1. Child is a leaf and not a unigram.
    // 2. Child is not a leaf, or is a unigram:
    //    2.1 Relative address can be represented by 30 bits.
    //    2.2 Relative address cannot be represented by 30 bits.
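    // For illustration, the resulting packing of <child_info> is:
    //   leaf child:      float bits of the logprob, with the LSB forced to
    //                    0 (an even number).
    //   in-range child:  (offset << 1) | 1, an odd positive number.
    //   overflow child:  -((index << 1) | 1), an odd negative number, where
    //                    <index> points into <overflow_buffer_>.
    // DecodeChildInfo() below inverts this encoding.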
    sorted_vec[i].second->SortChildren();
    for (int32 j = 0; j < sorted_vec[i].second->NumChildren(); ++j) {
      int32 child_info;
      if (sorted_vec[i].second->IsChildFinalOrder() ||
          sorted_vec[i].second->GetChild(j).second.state->MemSize() == 0) {
        // Child is a leaf and not a unigram. In this case we will not create
        // an entry in <lm_states_>; instead, we put the logprob in the place
        // where we would normally store the pointer.
        Int32AndFloat child_logprob_f;
        if (sorted_vec[i].second->IsChildFinalOrder()) {
          child_logprob_f.f = sorted_vec[i].second->GetChild(j).second.prob;
        } else {
          child_logprob_f.f =
              sorted_vec[i].second->GetChild(j).second.state->Logprob();
        }
        child_info = child_logprob_f.i;
        child_info &= ~1;  // Sets the last bit to 0 so <child_info> is even.
      } else {
        // Child is not a leaf, or is a unigram.
        int64 offset =
            sorted_vec[i].second->GetChild(j).second.state->MyAddress()
            - sorted_vec[i].second->MyAddress();
        KALDI_ASSERT(offset > 0);
        if (offset <= max_address_offset_) {
          // Relative address can be represented by 30 bits.
          child_info = offset * 2;
          child_info |= 1;
        } else {
          // Relative address cannot be represented by 30 bits, so we have to
          // put the child's address into <overflow_buffer_>.
          int32* abs_address = parent_address + offset;
          overflow_buffer_vec.push_back(abs_address);
          int32 overflow_buffer_index = overflow_buffer_vec.size() - 1;
          child_info = overflow_buffer_index * 2;
          child_info |= 1;
          child_info *= -1;
        }
      }
      // Child word.
      lm_states_[lm_states_index++] = sorted_vec[i].second->GetChild(j).first;
      // Child info.
      lm_states_[lm_states_index++] = child_info;
    }

    // If the current state corresponds to a unigram, then we create a
    // separate lookup table entry to improve efficiency, since unigrams will
    // be looked up pretty frequently.
    if (sorted_vec[i].second->IsUnigram()) {
      KALDI_ASSERT(sorted_vec[i].first->size() == 1);
      unigram_states_[(*sorted_vec[i].first)[0]] = parent_address;
    }
  }
  KALDI_ASSERT(lm_states_size_ == lm_states_index);

  // Moves <overflow_buffer_> from the vector holder to an array.
  overflow_buffer_size_ = overflow_buffer_vec.size();
  overflow_buffer_ = new int32*[overflow_buffer_size_];
  for (int32 i = 0; i < overflow_buffer_size_; ++i) {
    overflow_buffer_[i] = overflow_buffer_vec[i];
  }

  is_built_ = true;
}

void ConstArpaLmBuilder::Write(std::ostream &os, bool binary) const {
  if (!binary) {
    KALDI_ERR << "text-mode writing is not implemented for ConstArpaLmBuilder.";
  }
  KALDI_ASSERT(is_built_);

  // Creates ConstArpaLm.
  ConstArpaLm const_arpa_lm(
      Options().bos_symbol, Options().eos_symbol, Options().unk_symbol,
      ngram_order_, num_words_, overflow_buffer_size_, lm_states_size_,
      unigram_states_, overflow_buffer_, lm_states_);
  const_arpa_lm.Write(os, binary);
}

void ConstArpaLm::Write(std::ostream &os, bool binary) const {
  KALDI_ASSERT(initialized_);
  if (!binary) {
    KALDI_ERR << "text-mode writing is not implemented for ConstArpaLm.";
  }

  WriteToken(os, binary, "<ConstArpaLm>");

  // Misc info.
  WriteToken(os, binary, "<LmInfo>");
  WriteBasicType(os, binary, bos_symbol_);
  WriteBasicType(os, binary, eos_symbol_);
  WriteBasicType(os, binary, unk_symbol_);
  WriteBasicType(os, binary, ngram_order_);
  WriteToken(os, binary, "</LmInfo>");

  // LmStates section.
  WriteToken(os, binary, "<LmStates>");
  WriteBasicType(os, binary, lm_states_size_);
  os.write(reinterpret_cast<char *>(lm_states_),
           sizeof(int32) * lm_states_size_);
  if (!os.good()) {
    KALDI_ERR << "ConstArpaLm <LmStates> section writing failed.";
  }
  WriteToken(os, binary, "</LmStates>");

  // Unigram section. We write memory offsets to disk instead of the absolute
  // pointers.
  WriteToken(os, binary, "<LmUnigram>");
  WriteBasicType(os, binary, num_words_);
  int64* tmp_unigram_address = new int64[num_words_];
  for (int32 i = 0; i < num_words_; ++i) {
    // The relative address here is a little bit tricky:
    // 1. If the original address is NULL, then we set the relative address
    //    to zero.
    // 2. If the original address is not NULL, we set it to
    //      unigram_states_[i] - lm_states_ + 1;
    //    we add 1 to ensure that the above value is positive.
    tmp_unigram_address[i] = (unigram_states_[i] == NULL) ? 0 :
        unigram_states_[i] - lm_states_ + 1;
  }
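  // For illustration: if unigram_states_[i] points at lm_states_ + 42, the
  // value written to disk is 43, and a written 0 means the pointer was NULL;
  // reading inverts this mapping.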
  os.write(reinterpret_cast<char *>(tmp_unigram_address),
           sizeof(int64) * num_words_);
  if (!os.good()) {
    KALDI_ERR << "ConstArpaLm <LmUnigram> section writing failed.";
  }
  delete[] tmp_unigram_address;  // Releases the memory.
  tmp_unigram_address = NULL;
  WriteToken(os, binary, "</LmUnigram>");

  // Overflow section. We write memory offsets to disk instead of the absolute
  // pointers.
  WriteToken(os, binary, "<LmOverflow>");
  WriteBasicType(os, binary, overflow_buffer_size_);
  int64* tmp_overflow_address = new int64[overflow_buffer_size_];
  for (int32 i = 0; i < overflow_buffer_size_; ++i) {
    // The relative address here is a little bit tricky:
    // 1. If the original address is NULL, then we set the relative address
    //    to zero.
    // 2. If the original address is not NULL, we set it to
    //      overflow_buffer_[i] - lm_states_ + 1;
    //    we add 1 to ensure that the above value is positive.
    tmp_overflow_address[i] = (overflow_buffer_[i] == NULL) ? 0 :
        overflow_buffer_[i] - lm_states_ + 1;
  }
  os.write(reinterpret_cast<char *>(tmp_overflow_address),
           sizeof(int64) * overflow_buffer_size_);
  if (!os.good()) {
    KALDI_ERR << "ConstArpaLm <LmOverflow> section writing failed.";
  }
  delete[] tmp_overflow_address;
  tmp_overflow_address = NULL;
  WriteToken(os, binary, "</LmOverflow>");
  WriteToken(os, binary, "</ConstArpaLm>");
}

void ConstArpaLm::Read(std::istream &is, bool binary) {
  KALDI_ASSERT(!initialized_);
  if (!binary) {
    KALDI_ERR << "text-mode reading is not implemented for ConstArpaLm.";
  }
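
  // Format detection (a note on why peeking one byte works, assuming Kaldi's
  // usual binary I/O conventions): WriteBasicType() writes a one-byte size
  // marker (4 for int32) before each integer, and the old on-disk format
  // began with the int32 <bos_symbol_>. The new format begins with the token
  // "<ConstArpaLm>", whose first byte is '<'.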
  int first_char = is.peek();
  if (first_char == 4) {  // Old on-disk format starts with length of int32.
    ReadInternalOldFormat(is, binary);
  } else {  // New on-disk format starts with token <ConstArpaLm>.
    ReadInternal(is, binary);
  }
}

void ConstArpaLm::ReadInternal(std::istream &is, bool binary) {
  KALDI_ASSERT(!initialized_);
  if (!binary) {
    KALDI_ERR << "text-mode reading is not implemented for ConstArpaLm.";
  }

  ExpectToken(is, binary, "<ConstArpaLm>");

  // Misc info.
  ExpectToken(is, binary, "<LmInfo>");
  ReadBasicType(is, binary, &bos_symbol_);
  ReadBasicType(is, binary, &eos_symbol_);
  ReadBasicType(is, binary, &unk_symbol_);
  ReadBasicType(is, binary, &ngram_order_);
  ExpectToken(is, binary, "</LmInfo>");

  // LmStates section.
  ExpectToken(is, binary, "<LmStates>");
  ReadBasicType(is, binary, &lm_states_size_);
  lm_states_ = new int32[lm_states_size_];
  is.read(reinterpret_cast<char *>(lm_states_),
          sizeof(int32) * lm_states_size_);
  if (!is.good()) {
    KALDI_ERR << "ConstArpaLm <LmStates> section reading failed.";
  }
  ExpectToken(is, binary, "</LmStates>");

  // Unigram section. We write memory offsets to disk instead of the absolute
  // pointers.
  ExpectToken(is, binary, "<LmUnigram>");
  ReadBasicType(is, binary, &num_words_);
  unigram_states_ = new int32*[num_words_];
  int64* tmp_unigram_address = new int64[num_words_];
  is.read(reinterpret_cast<char *>(tmp_unigram_address),
          sizeof(int64) * num_words_);
  if (!is.good()) {
    KALDI_ERR << "ConstArpaLm <LmUnigram> section reading failed.";
  }
  for (int32 i = 0; i < num_words_; ++i) {
    // See how we compute the relative address in ConstArpaLm::Write().
    unigram_states_[i] = (tmp_unigram_address[i] == 0) ? NULL
        : lm_states_ + tmp_unigram_address[i] - 1;
  }
  delete[] tmp_unigram_address;
  tmp_unigram_address = NULL;
  ExpectToken(is, binary, "</LmUnigram>");

  // Overflow section. We write memory offsets to disk instead of the absolute
  // pointers.
  ExpectToken(is, binary, "<LmOverflow>");
  ReadBasicType(is, binary, &overflow_buffer_size_);
  overflow_buffer_ = new int32*[overflow_buffer_size_];
  int64* tmp_overflow_address = new int64[overflow_buffer_size_];
  is.read(reinterpret_cast<char *>(tmp_overflow_address),
          sizeof(int64) * overflow_buffer_size_);
  if (!is.good()) {
    KALDI_ERR << "ConstArpaLm <LmOverflow> section reading failed.";
  }
  for (int32 i = 0; i < overflow_buffer_size_; ++i) {
    // See how we compute the relative address in ConstArpaLm::Write().
    overflow_buffer_[i] = (tmp_overflow_address[i] == 0) ? NULL
        : lm_states_ + tmp_overflow_address[i] - 1;
  }
  delete[] tmp_overflow_address;
  tmp_overflow_address = NULL;
  ExpectToken(is, binary, "</LmOverflow>");
  ExpectToken(is, binary, "</ConstArpaLm>");

  KALDI_ASSERT(ngram_order_ > 0);
  KALDI_ASSERT(bos_symbol_ < num_words_ && bos_symbol_ > 0);
  KALDI_ASSERT(eos_symbol_ < num_words_ && eos_symbol_ > 0);
  KALDI_ASSERT(unk_symbol_ < num_words_ &&
               (unk_symbol_ > 0 || unk_symbol_ == -1));
  lm_states_end_ = lm_states_ + lm_states_size_ - 1;
  memory_assigned_ = true;
  initialized_ = true;
}

void ConstArpaLm::ReadInternalOldFormat(std::istream &is, bool binary) {
  KALDI_ASSERT(!initialized_);
  if (!binary) {
    KALDI_ERR << "text-mode reading is not implemented for ConstArpaLm.";
  }

  // Misc info.
  ReadBasicType(is, binary, &bos_symbol_);
  ReadBasicType(is, binary, &eos_symbol_);
  ReadBasicType(is, binary, &unk_symbol_);
  ReadBasicType(is, binary, &ngram_order_);

  // LmStates section.
  // In the deprecated version, <lm_states_size_> used to be of type int32,
  // which was a bug. We therefore read it as an int32 for back-compatibility.
  int32 lm_states_size_int32;
  ReadBasicType(is, binary, &lm_states_size_int32);
  lm_states_size_ = static_cast<int64>(lm_states_size_int32);
  lm_states_ = new int32[lm_states_size_];
  for (int64 i = 0; i < lm_states_size_; ++i) {
    ReadBasicType(is, binary, &lm_states_[i]);
  }

  // Unigram section. We write memory offsets to disk instead of the absolute
  // pointers.
  ReadBasicType(is, binary, &num_words_);
  unigram_states_ = new int32*[num_words_];
  for (int32 i = 0; i < num_words_; ++i) {
    int64 tmp_address;
    ReadBasicType(is, binary, &tmp_address);
    // See how we compute the relative address in ConstArpaLm::Write().
    unigram_states_[i] =
        (tmp_address == 0) ? NULL : lm_states_ + tmp_address - 1;
  }

  // Overflow section. We write memory offsets to disk instead of the absolute
  // pointers.
  ReadBasicType(is, binary, &overflow_buffer_size_);
  overflow_buffer_ = new int32*[overflow_buffer_size_];
  for (int32 i = 0; i < overflow_buffer_size_; ++i) {
    int64 tmp_address;
    ReadBasicType(is, binary, &tmp_address);
    // See how we compute the relative address in ConstArpaLm::Write().
    overflow_buffer_[i] =
        (tmp_address == 0) ? NULL : lm_states_ + tmp_address - 1;
  }
  KALDI_ASSERT(ngram_order_ > 0);
  KALDI_ASSERT(bos_symbol_ < num_words_ && bos_symbol_ > 0);
  KALDI_ASSERT(eos_symbol_ < num_words_ && eos_symbol_ > 0);
  KALDI_ASSERT(unk_symbol_ < num_words_ &&
               (unk_symbol_ > 0 || unk_symbol_ == -1));
  lm_states_end_ = lm_states_ + lm_states_size_ - 1;
  memory_assigned_ = true;
  initialized_ = true;
}

bool ConstArpaLm::HistoryStateExists(const std::vector<int32>& hist) const {
  // We do not create an LmState for the empty word sequence, but technically
  // it is the history state of all unigrams.
  if (hist.size() == 0) {
    return true;
  }

  // Tries to locate the LmState of the given word sequence.
  int32* lm_state = GetLmState(hist);
  if (lm_state == NULL) {
    // If <lm_state> does not exist, then <hist> has no children.
    return false;
  } else {
    // Note that we always create LmStates for unigrams, so even if <lm_state>
    // is not NULL, we still have to check whether it has children.
    KALDI_ASSERT(lm_state >= lm_states_);
    KALDI_ASSERT(lm_state + 2 <= lm_states_end_);
    // <lm_state + 2> points to <num_children>.
    if (*(lm_state + 2) > 0) {
      return true;
    } else {
      return false;
    }
  }
  return true;
}

float ConstArpaLm::GetNgramLogprob(const int32 word,
                                   const std::vector<int32>& hist) const {
  KALDI_ASSERT(initialized_);

  // If the history size plus one is larger than <ngram_order_>, remove the
  // oldest words.
  std::vector<int32> mapped_hist(hist);
  while (mapped_hist.size() >= ngram_order_) {
    mapped_hist.erase(mapped_hist.begin(), mapped_hist.begin() + 1);
  }
  KALDI_ASSERT(mapped_hist.size() + 1 <= ngram_order_);

  // TODO(guoguo): check with Dan if this is reasonable.
  // Maps possible out-of-vocabulary words to <unk>. If a word does not have a
  // corresponding LmState, we treat it as <unk>. We map it to <unk> only if
  // <unk> is specified.
  int32 mapped_word = word;
  if (unk_symbol_ != -1) {
    KALDI_ASSERT(mapped_word >= 0);
    if (mapped_word >= num_words_ || unigram_states_[mapped_word] == NULL) {
      mapped_word = unk_symbol_;
    }
    for (int32 i = 0; i < mapped_hist.size(); ++i) {
      KALDI_ASSERT(mapped_hist[i] >= 0);
      if (mapped_hist[i] >= num_words_ ||
          unigram_states_[mapped_hist[i]] == NULL) {
        mapped_hist[i] = unk_symbol_;
      }
    }
  }

  // Looks up the n-gram probability.
  return GetNgramLogprobRecurse(mapped_word, mapped_hist);
}

float ConstArpaLm::GetNgramLogprobRecurse(
    const int32 word, const std::vector<int32>& hist) const {
  KALDI_ASSERT(initialized_);
  KALDI_ASSERT(hist.size() + 1 <= ngram_order_);

  // Unigram case.
  if (hist.size() == 0) {
    if (word >= num_words_ || unigram_states_[word] == NULL) {
      // If <unk> is defined, then the word sequence should have already been
      // mapped to <unk> if necessary; this is for the case where <unk> is not
      // defined.
      return std::numeric_limits<float>::min();
    } else {
      Int32AndFloat logprob_i(*unigram_states_[word]);
      return logprob_i.f;
    }
  }

  // Higher n-gram orders.
  float logprob = 0.0;
  float backoff_logprob = 0.0;
  int32* state;
  if ((state = GetLmState(hist)) != NULL) {
    int32 child_info;
    int32* child_lm_state = NULL;
    if (GetChildInfo(word, state, &child_info)) {
      DecodeChildInfo(child_info, state, &child_lm_state, &logprob);
      return logprob;
    } else {
      Int32AndFloat backoff_logprob_i(*(state + 1));
      backoff_logprob = backoff_logprob_i.f;
    }
  }
  std::vector<int32> new_hist(hist);
  new_hist.erase(new_hist.begin(), new_hist.begin() + 1);
  return backoff_logprob + GetNgramLogprobRecurse(word, new_hist);
}
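
// For illustration: a query for P(C | A B) in a trigram model first looks
// for C among the children of the LmState of "A B". If found, that logprob
// is returned directly; otherwise the result is
//   backoff("A B") + GetNgramLogprobRecurse(C, {B}),
// bottoming out at the unigram logprob of C (or the minimum-float fallback
// handled above when <unk> is not defined).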

int32* ConstArpaLm::GetLmState(const std::vector<int32>& seq) const {
  KALDI_ASSERT(initialized_);

  // No LmState exists for the empty word sequence.
  if (seq.size() == 0) return NULL;

  // If <unk> is defined, then the word sequence should have already been
  // mapped to <unk> if necessary; this is for the case where <unk> is not
  // defined.
  if (seq[0] >= num_words_ || unigram_states_[seq[0]] == NULL) return NULL;
  int32* parent = unigram_states_[seq[0]];

  int32 child_info;
  int32* child_lm_state = NULL;
  float logprob;
  for (int32 i = 1; i < seq.size(); ++i) {
    if (!GetChildInfo(seq[i], parent, &child_info)) {
      return NULL;
    }
    DecodeChildInfo(child_info, parent, &child_lm_state, &logprob);
    if (child_lm_state == NULL) {
      return NULL;
    } else {
      parent = child_lm_state;
    }
  }
  return parent;
}

bool ConstArpaLm::GetChildInfo(const int32 word,
                               int32* parent, int32* child_info) const {
  KALDI_ASSERT(initialized_);

  KALDI_ASSERT(parent != NULL);
  KALDI_ASSERT(parent >= lm_states_);
  KALDI_ASSERT(child_info != NULL);

  KALDI_ASSERT(parent + 2 <= lm_states_end_);
  int32 num_children = *(parent + 2);
  KALDI_ASSERT(parent + 2 + 2 * num_children <= lm_states_end_);

  if (num_children == 0) return false;

  // A binary search into the children memory block.
  int32 start_index = 1;
  int32 end_index = num_children;
  while (start_index <= end_index) {
    int32 mid_index = round((start_index + end_index) / 2);
    int32 mid_word = *(parent + 1 + 2 * mid_index);
    if (mid_word == word) {
      *child_info = *(parent + 2 + 2 * mid_index);
      return true;
    } else if (mid_word < word) {
      start_index = mid_index + 1;
    } else {
      end_index = mid_index - 1;
    }
  }

  return false;
}
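
// Note on the indexing in GetChildInfo(): the children were sorted by
// word-id in ReadComplete(), which is what makes the binary search valid.
// The search uses 1-based child indices, so child k's word is at
// parent + 1 + 2k and its packed info at parent + 2 + 2k; the first child
// sits at parent + 3, directly after the <num_children> field.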

void ConstArpaLm::DecodeChildInfo(const int32 child_info,
                                  int32* parent,
                                  int32** child_lm_state,
                                  float* logprob) const {
  KALDI_ASSERT(initialized_);

  KALDI_ASSERT(logprob != NULL);
  if (child_info % 2 == 0) {
    // Child is a leaf; we only return the log probability.
    *child_lm_state = NULL;
    Int32AndFloat logprob_i(child_info);
    *logprob = logprob_i.f;
  } else {
    int32 child_offset = child_info / 2;
    if (child_offset > 0) {
      *child_lm_state = parent + child_offset;
      Int32AndFloat logprob_i(**child_lm_state);
      *logprob = logprob_i.f;
    } else {
      KALDI_ASSERT(-child_offset < overflow_buffer_size_);
      *child_lm_state = overflow_buffer_[-child_offset];
      Int32AndFloat logprob_i(**child_lm_state);
      *logprob = logprob_i.f;
    }
    KALDI_ASSERT(*child_lm_state >= lm_states_);
    KALDI_ASSERT(*child_lm_state <= lm_states_end_);
  }
}

void ConstArpaLm::WriteArpaRecurse(int32* lm_state,
                                   const std::vector<int32>& seq,
                                   std::vector<ArpaLine> *output) const {
  if (lm_state == NULL) return;

  KALDI_ASSERT(lm_state >= lm_states_);
  KALDI_ASSERT(lm_state + 2 <= lm_states_end_);

  // Inserts the current LmState into <output>.
  ArpaLine arpa_line;
  arpa_line.words = seq;
  Int32AndFloat logprob_i(*lm_state);
  arpa_line.logprob = logprob_i.f;
  Int32AndFloat backoff_logprob_i(*(lm_state + 1));
  arpa_line.backoff_logprob = backoff_logprob_i.f;
  output->push_back(arpa_line);

  // Scans for possible children, and recursively adds each child to <output>.
  int32 num_children = *(lm_state + 2);
  KALDI_ASSERT(lm_state + 2 + 2 * num_children <= lm_states_end_);
  for (int32 i = 0; i < num_children; ++i) {
    std::vector<int32> new_seq(seq);
    new_seq.push_back(*(lm_state + 3 + 2 * i));
    int32 child_info = *(lm_state + 4 + 2 * i);
    float logprob;
    int32* child_lm_state = NULL;
    DecodeChildInfo(child_info, lm_state, &child_lm_state, &logprob);

    if (child_lm_state == NULL) {
      // Leaf case.
      ArpaLine child_arpa_line;
      child_arpa_line.words = new_seq;
      child_arpa_line.logprob = logprob;
      child_arpa_line.backoff_logprob = 0.0;
      output->push_back(child_arpa_line);
    } else {
      WriteArpaRecurse(child_lm_state, new_seq, output);
    }
  }
}

void ConstArpaLm::WriteArpa(std::ostream &os) const {
  KALDI_ASSERT(initialized_);

  std::vector<ArpaLine> tmp_output;
  for (int32 i = 0; i < num_words_; ++i) {
    if (unigram_states_[i] != NULL) {
      std::vector<int32> seq(1, i);
      WriteArpaRecurse(unigram_states_[i], seq, &tmp_output);
    }
  }

  // Sorts the ArpaLines and collects the header information.
  std::sort(tmp_output.begin(), tmp_output.end());
  std::vector<int32> ngram_count(1, 0);
  for (int32 i = 0; i < tmp_output.size(); ++i) {
    if (tmp_output[i].words.size() >= ngram_count.size()) {
      ngram_count.resize(tmp_output[i].words.size() + 1);
      ngram_count[tmp_output[i].words.size()] = 1;
    } else {
      ngram_count[tmp_output[i].words.size()] += 1;
    }
  }

  // Writes the header.
  os << std::endl;
  os << "\\data\\" << std::endl;
  for (int32 i = 1; i < ngram_count.size(); ++i) {
    os << "ngram " << i << "=" << ngram_count[i] << std::endl;
  }

  // Writes the n-grams.
  int32 current_order = 0;
  for (int32 i = 0; i < tmp_output.size(); ++i) {
    // Beginning of an n-gram section.
    if (tmp_output[i].words.size() != current_order) {
      current_order = tmp_output[i].words.size();
      os << std::endl;
      os << "\\" << current_order << "-grams:" << std::endl;
    }

    // Writes the logprob.
    os << tmp_output[i].logprob << '\t';

    // Writes the word sequence.
    for (int32 j = 0; j < tmp_output[i].words.size(); ++j) {
      os << tmp_output[i].words[j];
      if (j != tmp_output[i].words.size() - 1) {
        os << " ";
      }
    }

    // Writes the backoff_logprob if it is not zero.
    if (tmp_output[i].backoff_logprob != 0.0) {
      os << '\t' << tmp_output[i].backoff_logprob;
    }
    os << std::endl;
  }

  os << std::endl << "\\end\\" << std::endl;
}
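
// For illustration, WriteArpa() emits the standard ARPA layout, but with
// integer word-ids in place of word strings (the values below are made up):
//
//   \data\
//   ngram 1=3
//   ngram 2=2
//
//   \1-grams:
//   -2.34   11   -0.51
//   ...
//
//   \2-grams:
//   -1.02   11 42
//   ...
//
//   \end\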

ConstArpaLmDeterministicFst::ConstArpaLmDeterministicFst(
    const ConstArpaLm& lm) : lm_(lm) {
  // Creates a history state for <s>.
  std::vector<Label> bos_state(1, lm_.BosSymbol());
  state_to_wseq_.push_back(bos_state);
  wseq_to_state_[bos_state] = 0;
  start_state_ = 0;
}

fst::StdArc::Weight ConstArpaLmDeterministicFst::Final(StateId s) {
  // At this point, we should have created the state.
  KALDI_ASSERT(static_cast<size_t>(s) < state_to_wseq_.size());
  const std::vector<Label>& wseq = state_to_wseq_[s];
  float logprob = lm_.GetNgramLogprob(lm_.EosSymbol(), wseq);
  return Weight(-logprob);
}

bool ConstArpaLmDeterministicFst::GetArc(StateId s,
                                         Label ilabel, fst::StdArc *oarc) {
  // At this point, we should have created the state.
  KALDI_ASSERT(static_cast<size_t>(s) < state_to_wseq_.size());
  std::vector<Label> wseq = state_to_wseq_[s];

  float logprob = lm_.GetNgramLogprob(ilabel, wseq);
  if (logprob == std::numeric_limits<float>::min()) {
    return false;
  }

  // Locates the next state in the ConstArpaLm. Note that OOVs and backoff
  // have already been taken care of in ConstArpaLm.
  wseq.push_back(ilabel);
  while (wseq.size() >= lm_.NgramOrder()) {
    // A history state has at most lm_.NgramOrder() - 1 words in it.
    wseq.erase(wseq.begin(), wseq.begin() + 1);
  }
  while (!lm_.HistoryStateExists(wseq)) {
    KALDI_ASSERT(wseq.size() > 0);
    wseq.erase(wseq.begin(), wseq.begin() + 1);
  }

  std::pair<const std::vector<Label>, StateId> wseq_state_pair(
      wseq, static_cast<Label>(state_to_wseq_.size()));

  // Attempts to insert the current <wseq_state_pair>. If the pair already
  // exists, then the insert returns false.
  typedef MapType::iterator IterType;
  std::pair<IterType, bool> result = wseq_to_state_.insert(wseq_state_pair);

  // If the pair was just inserted, then also add it to <state_to_wseq_>.
  if (result.second == true)
    state_to_wseq_.push_back(wseq);

  // Creates the arc.
  oarc->ilabel = ilabel;
  oarc->olabel = ilabel;
  oarc->nextstate = result.first->second;
  oarc->weight = Weight(-logprob);

  return true;
}

bool BuildConstArpaLm(const ArpaParseOptions& options,
                      const std::string& arpa_rxfilename,
                      const std::string& const_arpa_wxfilename) {
  ConstArpaLmBuilder lm_builder(options);
  KALDI_LOG << "Reading " << arpa_rxfilename;
  Input ki(arpa_rxfilename);
  lm_builder.Read(ki.Stream());
  WriteKaldiObject(lm_builder, const_arpa_wxfilename, true);
  return true;
}
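
// A minimal usage sketch (hypothetical caller code; the symbol-ids are made
// up and would normally come from the words.txt symbol table):
//
//   ArpaParseOptions options;
//   options.bos_symbol = 1;  // <s>
//   options.eos_symbol = 2;  // </s>
//   options.unk_symbol = 3;  // <unk>, or -1 if the LM has none.
//   BuildConstArpaLm(options, "lm.arpa", "lm.const_arpa");
//
// The resulting file can be loaded back with ConstArpaLm::Read() and queried
// with GetNgramLogprob(), or wrapped in a ConstArpaLmDeterministicFst for
// on-demand composition during decoding.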

}  // namespace kaldi