42 if (words.size() < other.
words.size()) {
44 }
else if (words.size() > other.
words.size()) {
47 return words < other.
words;
67 const std::pair<int32, union ChildType>& lhs,
68 const std::pair<int32, union ChildType>& rhs)
const {
69 return lhs.first < rhs.first;
73 LmState(
const bool is_unigram,
const bool is_child_final_order,
75 is_unigram_(is_unigram), is_child_final_order_(is_child_final_order),
76 logprob_(logprob), backoff_logprob_(backoff_logprob) {}
79 my_address_ = address;
85 child.
state = child_state;
86 children_.push_back(std::make_pair(word, child));
92 child.
prob = child_prob;
93 children_.push_back(std::make_pair(word, child));
105 return is_child_final_order_;
113 return backoff_logprob_;
117 return children_.size();
123 return children_[index];
132 return (backoff_logprob_ == 0.0 && children_.empty());
138 if (IsLeaf() && !is_unigram_) {
145 return (3 + 2 * children_.size());
172 std::vector<std::pair<int32, union ChildType> >
children_;
183 overflow_buffer_size_ = 0;
185 max_address_offset_ = pow(2, 30) - 1;
188 unigram_states_ = NULL;
189 overflow_buffer_ = NULL;
193 unordered_map<std::vector<int32>,
195 for (iter = seq_to_state_.begin(); iter != seq_to_state_.end(); ++iter) {
200 delete[] unigram_states_;
201 delete[] overflow_buffer_;
206 void Write(std::ostream &os,
bool binary)
const;
209 KALDI_WARN <<
"You are changing <max_address_offset_>; the default should " 210 <<
"not be changed unless you are in testing mode.";
211 max_address_offset_ = max_address_offset;
216 virtual void HeaderAvailable();
217 virtual void ConsumeNGram(
const NGram& ngram);
218 virtual void ReadComplete();
223 const std::pair<std::vector<int32>*,
LmState*>& lhs,
224 const std::pair<std::vector<int32>*,
LmState*>& rhs)
const {
225 return *(lhs.first) < *(rhs.first);
263 unordered_map<std::vector<int32>,
268 ngram_order_ = NgramCounts().size();
276 if (cur_order != ngram_order_ || ngram_order_ == 1) {
277 lm_state =
new LmState(cur_order == 1,
278 cur_order == ngram_order_ - 1,
281 if (seq_to_state_.find(ngram.
words) != seq_to_state_.end()) {
282 std::ostringstream os;
284 for (
size_t i = 0;
i < ngram.
words.size();
i++) {
285 os << ngram.
words[
i] <<
" ";
289 KALDI_ERR <<
"N-gram " << os.str() <<
" appears twice in the arpa file";
291 seq_to_state_[ngram.
words] = lm_state;
303 std::vector<int32> hist(ngram.
words.begin(), ngram.
words.end() - 1);
304 unordered_map<std::vector<int32>,
306 hist_iter = seq_to_state_.find(hist);
307 if (hist_iter == seq_to_state_.end()) {
308 std::ostringstream ss;
309 for (
int i = 0;
i < cur_order; ++
i)
310 ss << (
i == 0 ?
'[' :
' ') << ngram.
words[
i];
311 KALDI_ERR <<
"In line " << LineNumber() <<
": " 312 << cur_order <<
"-gram " << ss.str() <<
"] does not have " 313 <<
"a parent model " << cur_order <<
"-gram.";
315 if (cur_order != ngram_order_ || ngram_order_ == 1) {
318 hist_iter->second->AddChild(last_word, lm_state);
322 hist_iter->second->AddChild(last_word, ngram.
logprob);
326 num_words_ = std::max(num_words_, last_word + 1);
359 std::vector<std::pair<std::vector<int32>*,
LmState*> > sorted_vec;
360 unordered_map<std::vector<int32>,
362 for (iter = seq_to_state_.begin(); iter != seq_to_state_.end(); ++iter) {
363 if (iter->second->MemSize() > 0) {
364 sorted_vec.push_back(
365 std::make_pair(
const_cast<std::vector<int32>*
>(&(iter->first)),
370 std::sort(sorted_vec.begin(), sorted_vec.end(),
374 for (
int32 i = 0;
i < sorted_vec.size(); ++
i) {
375 lm_states_size_ += sorted_vec[
i].second->MemSize();
377 sorted_vec[
i].second->SetMyAddress(0);
379 sorted_vec[
i].second->SetMyAddress(sorted_vec[
i - 1].second->MyAddress()
380 + sorted_vec[
i - 1].second->MemSize());
386 int64 lm_states_index = 0;
388 lm_states_ =
new int32[lm_states_size_];
389 }
catch(
const std::exception &e) {
394 unigram_states_ =
new int32*[num_words_];
395 std::vector<int32*> overflow_buffer_vec;
396 for (
int32 i = 0;
i < num_words_; ++
i) {
397 unigram_states_[
i] = NULL;
399 for (
int32 i = 0;
i < sorted_vec.size(); ++
i) {
401 int32* parent_address = lm_states_ + lm_states_index;
405 lm_states_[lm_states_index++] = logprob_f.
i;
408 Int32AndFloat backoff_logprob_f(sorted_vec[
i].second->BackoffLogprob());
409 lm_states_[lm_states_index++] = backoff_logprob_f.
i;
412 lm_states_[lm_states_index++] = sorted_vec[
i].second->NumChildren();
419 sorted_vec[
i].second->SortChildren();
420 for (
int32 j = 0;
j < sorted_vec[
i].second->NumChildren(); ++
j) {
422 if (sorted_vec[
i].second->IsChildFinalOrder() ||
423 sorted_vec[
i].second->GetChild(
j).second.state->MemSize() == 0) {
428 if (sorted_vec[
i].second->IsChildFinalOrder()) {
429 child_logprob_f.
f = sorted_vec[
i].second->GetChild(
j).second.prob;
432 sorted_vec[
i].second->GetChild(
j).second.state->Logprob();
434 child_info = child_logprob_f.
i;
439 sorted_vec[
i].second->GetChild(
j).second.state->MyAddress()
440 - sorted_vec[
i].second->MyAddress();
442 if (offset <= max_address_offset_) {
444 child_info = offset * 2;
449 int32* abs_address = parent_address + offset;
450 overflow_buffer_vec.push_back(abs_address);
451 int32 overflow_buffer_index = overflow_buffer_vec.size() - 1;
452 child_info = overflow_buffer_index * 2;
458 lm_states_[lm_states_index++] = sorted_vec[
i].second->GetChild(
j).first;
460 lm_states_[lm_states_index++] = child_info;
466 if (sorted_vec[
i].second->IsUnigram()) {
468 unigram_states_[(*sorted_vec[
i].first)[0]] = parent_address;
474 overflow_buffer_size_ = overflow_buffer_vec.size();
475 overflow_buffer_ =
new int32*[overflow_buffer_size_];
476 for (
int32 i = 0;
i < overflow_buffer_size_; ++
i) {
477 overflow_buffer_[
i] = overflow_buffer_vec[
i];
485 KALDI_ERR <<
"text-mode writing is not implemented for ConstArpaLmBuilder.";
491 Options().bos_symbol, Options().eos_symbol, Options().unk_symbol,
492 ngram_order_, num_words_, overflow_buffer_size_, lm_states_size_,
493 unigram_states_, overflow_buffer_, lm_states_);
494 const_arpa_lm.
Write(os, binary);
500 KALDI_ERR <<
"text-mode writing is not implemented for ConstArpaLm.";
516 os.write(reinterpret_cast<char *>(lm_states_),
517 sizeof(
int32) * lm_states_size_);
519 KALDI_ERR <<
"ConstArpaLm <LmStates> section writing failed.";
527 int64* tmp_unigram_address =
new int64[num_words_];
528 for (
int32 i = 0;
i < num_words_; ++
i) {
535 tmp_unigram_address[
i] = (unigram_states_[
i] == NULL) ? 0 :
536 unigram_states_[
i] - lm_states_ + 1;
538 os.write(reinterpret_cast<char *>(tmp_unigram_address),
539 sizeof(int64) * num_words_);
541 KALDI_ERR <<
"ConstArpaLm <LmUnigram> section writing failed.";
543 delete[] tmp_unigram_address;
544 tmp_unigram_address = NULL;
551 int64* tmp_overflow_address =
new int64[overflow_buffer_size_];
552 for (
int32 i = 0;
i < overflow_buffer_size_; ++
i) {
559 tmp_overflow_address[
i] = (overflow_buffer_[
i] == NULL) ? 0 :
560 overflow_buffer_[
i] - lm_states_ + 1;
562 os.write(reinterpret_cast<char *>(tmp_overflow_address),
563 sizeof(int64) * overflow_buffer_size_);
565 KALDI_ERR <<
"ConstArpaLm <LmOverflow> section writing failed.";
567 delete[] tmp_overflow_address;
568 tmp_overflow_address = NULL;
576 KALDI_ERR <<
"text-mode reading is not implemented for ConstArpaLm.";
579 int first_char = is.peek();
580 if (first_char == 4) {
581 ReadInternalOldFormat(is, binary);
583 ReadInternal(is, binary);
590 KALDI_ERR <<
"text-mode reading is not implemented for ConstArpaLm.";
606 lm_states_ =
new int32[lm_states_size_];
607 is.read(reinterpret_cast<char *>(lm_states_),
608 sizeof(
int32) * lm_states_size_);
610 KALDI_ERR <<
"ConstArpaLm <LmStates> section reading failed.";
618 unigram_states_ =
new int32*[num_words_];
619 int64* tmp_unigram_address =
new int64[num_words_];
620 is.read(reinterpret_cast<char *>(tmp_unigram_address),
621 sizeof(int64) * num_words_);
623 KALDI_ERR <<
"ConstArpaLm <LmUnigram> section reading failed.";
625 for (
int32 i = 0;
i < num_words_; ++
i) {
627 unigram_states_[
i] = (tmp_unigram_address[
i] == 0) ? NULL
628 : lm_states_ + tmp_unigram_address[
i] - 1;
630 delete[] tmp_unigram_address;
631 tmp_unigram_address = NULL;
638 overflow_buffer_ =
new int32*[overflow_buffer_size_];
639 int64* tmp_overflow_address =
new int64[overflow_buffer_size_];
640 is.read(reinterpret_cast<char *>(tmp_overflow_address),
641 sizeof(int64) * overflow_buffer_size_);
643 KALDI_ERR <<
"ConstArpaLm <LmOverflow> section reading failed.";
645 for (
int32 i = 0;
i < overflow_buffer_size_; ++
i) {
647 overflow_buffer_[
i] = (tmp_overflow_address[
i] == 0) ? NULL
648 : lm_states_ + tmp_overflow_address[
i] - 1;
650 delete[] tmp_overflow_address;
651 tmp_overflow_address = NULL;
656 KALDI_ASSERT(bos_symbol_ < num_words_ && bos_symbol_ > 0);
657 KALDI_ASSERT(eos_symbol_ < num_words_ && eos_symbol_ > 0);
659 (unk_symbol_ > 0 || unk_symbol_ == -1));
660 lm_states_end_ = lm_states_ + lm_states_size_ - 1;
661 memory_assigned_ =
true;
668 KALDI_ERR <<
"text-mode reading is not implemented for ConstArpaLm.";
680 int32 lm_states_size_int32;
682 lm_states_size_ =
static_cast<int64
>(lm_states_size_int32);
683 lm_states_ =
new int32[lm_states_size_];
684 for (int64
i = 0;
i < lm_states_size_; ++
i) {
691 unigram_states_ =
new int32*[num_words_];
692 for (
int32 i = 0;
i < num_words_; ++
i) {
697 (tmp_address == 0) ? NULL : lm_states_ + tmp_address - 1;
703 overflow_buffer_ =
new int32*[overflow_buffer_size_];
704 for (
int32 i = 0;
i < overflow_buffer_size_; ++
i) {
708 overflow_buffer_[
i] =
709 (tmp_address == 0) ? NULL : lm_states_ + tmp_address - 1;
712 KALDI_ASSERT(bos_symbol_ < num_words_ && bos_symbol_ > 0);
713 KALDI_ASSERT(eos_symbol_ < num_words_ && eos_symbol_ > 0);
715 (unk_symbol_ > 0 || unk_symbol_ == -1));
716 lm_states_end_ = lm_states_ + lm_states_size_ - 1;
717 memory_assigned_ =
true;
724 if (hist.size() == 0) {
729 int32* lm_state = GetLmState(hist);
730 if (lm_state == NULL) {
739 if (*(lm_state + 2) > 0) {
749 const std::vector<int32>& hist)
const {
754 std::vector<int32> mapped_hist(hist);
755 while (mapped_hist.size() >= ngram_order_) {
756 mapped_hist.erase(mapped_hist.begin(), mapped_hist.begin() + 1);
764 int32 mapped_word = word;
765 if (unk_symbol_ != -1) {
767 if (mapped_word >= num_words_ || unigram_states_[mapped_word] == NULL) {
768 mapped_word = unk_symbol_;
770 for (
int32 i = 0;
i < mapped_hist.size(); ++
i) {
772 if (mapped_hist[
i] >= num_words_ ||
773 unigram_states_[mapped_hist[
i]] == NULL) {
774 mapped_hist[
i] = unk_symbol_;
780 return GetNgramLogprobRecurse(mapped_word, mapped_hist);
784 const int32 word,
const std::vector<int32>& hist)
const {
789 if (hist.size() == 0) {
790 if (word >= num_words_ || unigram_states_[word] == NULL) {
794 return std::numeric_limits<float>::min();
805 if ((state = GetLmState(hist)) != NULL) {
807 int32* child_lm_state = NULL;
808 if (GetChildInfo(word, state, &child_info)) {
809 DecodeChildInfo(child_info, state, &child_lm_state, &logprob);
813 backoff_logprob = backoff_logprob_i.
f;
816 std::vector<int32> new_hist(hist);
817 new_hist.erase(new_hist.begin(), new_hist.begin() + 1);
818 return backoff_logprob + GetNgramLogprobRecurse(word, new_hist);
825 if (seq.size() == 0)
return NULL;
829 if (seq[0] >= num_words_ || unigram_states_[seq[0]] == NULL)
return NULL;
830 int32* parent = unigram_states_[seq[0]];
833 int32* child_lm_state = NULL;
835 for (
int32 i = 1;
i < seq.size(); ++
i) {
836 if (!GetChildInfo(seq[
i], parent, &child_info)) {
839 DecodeChildInfo(child_info, parent, &child_lm_state, &logprob);
840 if (child_lm_state == NULL) {
843 parent = child_lm_state;
858 int32 num_children = *(parent + 2);
859 KALDI_ASSERT(parent + 2 + 2 * num_children <= lm_states_end_);
861 if (num_children == 0)
return false;
864 int32 start_index = 1;
865 int32 end_index = num_children;
866 while (start_index <= end_index) {
867 int32 mid_index = round((start_index + end_index) / 2);
868 int32 mid_word = *(parent + 1 + 2 * mid_index);
869 if (mid_word == word) {
870 *child_info = *(parent + 2 + 2 * mid_index);
872 }
else if (mid_word < word) {
873 start_index = mid_index + 1;
875 end_index = mid_index - 1;
884 int32** child_lm_state,
889 if (child_info % 2 == 0) {
891 *child_lm_state = NULL;
893 *logprob = logprob_i.
f;
895 int32 child_offset = child_info / 2;
896 if (child_offset > 0) {
897 *child_lm_state = parent + child_offset;
899 *logprob = logprob_i.
f;
902 *child_lm_state = overflow_buffer_[-child_offset];
904 *logprob = logprob_i.
f;
912 const std::vector<int32>& seq,
913 std::vector<ArpaLine> *output)
const {
914 if (lm_state == NULL)
return;
921 arpa_line.
words = seq;
926 output->push_back(arpa_line);
929 int32 num_children = *(lm_state + 2);
930 KALDI_ASSERT(lm_state + 2 + 2 * num_children <= lm_states_end_);
931 for (
int32 i = 0;
i < num_children; ++
i) {
932 std::vector<int32> new_seq(seq);
933 new_seq.push_back(*(lm_state + 3 + 2 *
i));
934 int32 child_info = *(lm_state + 4 + 2 *
i);
936 int32* child_lm_state = NULL;
937 DecodeChildInfo(child_info, lm_state, &child_lm_state, &logprob);
939 if (child_lm_state == NULL) {
942 child_arpa_line.
words = new_seq;
945 output->push_back(child_arpa_line);
947 WriteArpaRecurse(child_lm_state, new_seq, output);
955 std::vector<ArpaLine> tmp_output;
956 for (
int32 i = 0;
i < num_words_; ++
i) {
957 if (unigram_states_[
i] != NULL) {
958 std::vector<int32> seq(1,
i);
959 WriteArpaRecurse(unigram_states_[
i], seq, &tmp_output);
964 std::sort(tmp_output.begin(), tmp_output.end());
965 std::vector<int32> ngram_count(1, 0);
966 for (
int32 i = 0;
i < tmp_output.size(); ++
i) {
967 if (tmp_output[
i].
words.size() >= ngram_count.size()) {
968 ngram_count.resize(tmp_output[
i].
words.size() + 1);
969 ngram_count[tmp_output[
i].words.size()] = 1;
971 ngram_count[tmp_output[
i].words.size()] += 1;
977 os <<
"\\data\\" << std::endl;
978 for (
int32 i = 1;
i < ngram_count.size(); ++
i) {
979 os <<
"ngram " <<
i <<
"=" << ngram_count[
i] << std::endl;
983 int32 current_order = 0;
984 for (
int32 i = 0;
i < tmp_output.size(); ++
i) {
986 if (tmp_output[
i].
words.size() != current_order) {
987 current_order = tmp_output[
i].words.size();
989 os <<
"\\" << current_order <<
"-grams:" << std::endl;
993 os << tmp_output[
i].logprob <<
'\t';
996 for (
int32 j = 0;
j < tmp_output[
i].words.size(); ++
j) {
997 os << tmp_output[
i].words[
j];
998 if (
j != tmp_output[
i].
words.size() - 1) {
1005 os <<
'\t' << tmp_output[
i].backoff_logprob;
1010 os << std::endl <<
"\\end\\" << std::endl;
1037 if (logprob == std::numeric_limits<float>::min()) {
1043 wseq.push_back(ilabel);
1046 wseq.erase(wseq.begin(), wseq.begin() + 1);
1050 wseq.erase(wseq.begin(), wseq.begin() + 1);
1053 std::pair<const std::vector<Label>,
StateId> wseq_state_pair(
1058 typedef MapType::iterator IterType;
1059 std::pair<IterType, bool> result =
wseq_to_state_.insert(wseq_state_pair);
1062 if (result.second ==
true)
1066 oarc->ilabel = ilabel;
1067 oarc->olabel = ilabel;
1068 oarc->nextstate = result.first->second;
1069 oarc->weight =
Weight(-logprob);
1075 const std::string& arpa_rxfilename,
1076 const std::string& const_arpa_wxfilename) {
1078 KALDI_LOG <<
"Reading " << arpa_rxfilename;
1079 Input ki(arpa_rxfilename);
ArpaFileParser is an abstract base class for ARPA LM file conversion.
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
fst::StdArc::Weight Weight
bool is_child_final_order_
std::vector< std::vector< Label > > state_to_wseq_
void Write(std::ostream &os, bool binary) const
A hashing function-object for vectors.
void Read(std::istream &is, bool binary)
unordered_map< std::vector< int32 >, LmState *, VectorHasher< int32 > > seq_to_state_
bool operator()(const std::pair< std::vector< int32 > *, LmState *> &lhs, const std::pair< std::vector< int32 > *, LmState *> &rhs) const
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
virtual void ReadComplete()
Override function called after the last n-gram has been consumed.
Options that control ArpaFileParser.
float logprob
Log-prob of the n-gram.
void ReadInternalOldFormat(std::istream &is, bool binary)
LmState(const bool is_unigram, const bool is_child_final_order, const float logprob, const float backoff_logprob)
void SetMaxAddressOffset(const int32 max_address_offset)
virtual bool GetArc(StateId s, Label ilabel, fst::StdArc *oarc)
ConstArpaLmDeterministicFst(const ConstArpaLm &lm)
void AddChild(const int32 word, LmState *child_state)
void Write(std::ostream &os, bool binary) const
int32 overflow_buffer_size_
virtual void HeaderAvailable()
Override function called to signal that the ARPA header with the expected number of n-grams has been read...
float backoff
log-backoff weight of the n-gram.
int32 ** overflow_buffer_
void ExpectToken(std::istream &is, bool binary, const char *token)
ExpectToken tries to read in the given token, and throws an exception on failure. ...
void WriteArpaRecurse(int32 *lm_state, const std::vector< int32 > &seq, std::vector< ArpaLine > *output) const
void DecodeChildInfo(const int32 child_info, int32 *parent, int32 **child_lm_state, float *logprob) const
std::vector< int32 > words
Symbols in left to right order.
void SetMyAddress(const int64 address)
std::pair< int32, union ChildType > GetChild(const int32 index)
void Read(std::istream &is)
Read ARPA LM file from a stream.
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
int32 * GetLmState(const std::vector< int32 > &seq) const
fst::StdArc::Weight Weight
std::vector< std::pair< int32, union ChildType > > children_
void AddChild(const int32 word, const float child_prob)
fst::StdArc::StateId StateId
#define KALDI_ASSERT(cond)
float GetNgramLogprob(const int32 word, const std::vector< int32 > &hist) const
bool operator<(const ArpaLine &other) const
bool operator()(const std::pair< int32, union ChildType > &lhs, const std::pair< int32, union ChildType > &rhs) const
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
float GetNgramLogprobRecurse(const int32 word, const std::vector< int32 > &hist) const
bool BuildConstArpaLm(const ArpaParseOptions &options, const std::string &arpa_rxfilename, const std::string &const_arpa_wxfilename)
virtual void ConsumeNGram(const NGram &ngram)
Pure override that must be implemented to process the current n-gram.
virtual Weight Final(StateId s)
std::vector< int32 > words
int32 NumChildren() const
A parsed n-gram from ARPA LM file.
int32 max_address_offset_
ConstArpaLmBuilder(ArpaParseOptions options)
bool IsChildFinalOrder() const
void ReadInternal(std::istream &is, bool binary)
bool HistoryStateExists(const std::vector< int32 > &hist) const
float BackoffLogprob() const
void WriteArpa(std::ostream &os) const
bool GetChildInfo(const int32 word, int32 *parent, int32 *child_info) const