kaldi-rnnlm.cc
Go to the documentation of this file.
1 // lm/kaldi-rnnlm.cc
2 
3 // Copyright 2015 Guoguo Chen
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include <utility>
21 
22 #include "lm/kaldi-rnnlm.h"
23 #include "util/stl-utils.h"
24 #include "util/text-utils.h"
25 
26 namespace kaldi {
27 
29  const KaldiRnnlmWrapperOpts &opts,
30  const std::string &unk_prob_rspecifier,
31  const std::string &word_symbol_table_rxfilename,
32  const std::string &rnnlm_rxfilename) {
// Constructor body: configures the wrapped rnnlm_ object and builds
// label_to_word_, the table that maps integer word ids to word strings.
//
// NOTE(review): the listing numbering jumps 33 -> 36 -> 38, so statements at
// lines 34, 35 and 37 are not visible here -- confirm against the real file
// before relying on this being the complete initialization sequence.
33  rnnlm_.setRnnLMFile(rnnlm_rxfilename);
36  rnnlm_.setUnkPenalty(unk_prob_rspecifier);
38 
39  // Reads symbol table.
// A NULL result from ReadText is treated as a fatal error via KALDI_ERR.
// NOTE(review): word_symbols is never released in the visible code; if
// ReadText hands ownership to the caller this is a leak -- TODO confirm.
40  fst::SymbolTable *word_symbols = NULL;
41  if (!(word_symbols =
42  fst::SymbolTable::ReadText(word_symbol_table_rxfilename))) {
43  KALDI_ERR << "Could not read symbol table from file "
44  << word_symbol_table_rxfilename;
45  }
// One slot per symbol plus one extra slot at the end, which is filled with
// the end-of-sentence symbol below.
46  label_to_word_.resize(word_symbols->NumSymbols() + 1);
// Fills every slot except the last one. An empty string returned by Find(i)
// is reported as an error, so ids 0 .. NumSymbols()-1 must all be present.
47  for (int32 i = 0; i < label_to_word_.size() - 1; ++i) {
48  label_to_word_[i] = word_symbols->Find(i);
49  if (label_to_word_[i] == "") {
// NOTE(review): the message below has no space between the integer and
// "in the word", so it prints e.g. "integer 5in the word".
50  KALDI_ERR << "Could not find word for integer " << i << "in the word "
51  << "symbol table, mismatched symbol table or you have discontinuous "
52  << "integers in your symbol table?";
53  }
54  }
// The last index holds opts.eos_symbol, and eos_ records that index.
55  label_to_word_[label_to_word_.size() - 1] = opts.eos_symbol;
56  eos_ = label_to_word_.size() - 1;
57 }
58 
60  int32 word, const std::vector<int32> &wseq,
61  const std::vector<float> &context_in,
62  std::vector<float> *context_out) {
63 
// Translates the integer history <wseq> into word strings through
// label_to_word_, then forwards the query to
// rnnlm_.computeConditionalLogprob() and returns its result unchanged.
//
//   word         integer id of the word being scored.
//   wseq         integer ids of the history words.
//   context_in   passed through to computeConditionalLogprob() as-is.
//   context_out  passed through to computeConditionalLogprob() as-is.
//
// Every id in <wseq> is asserted to be below label_to_word_.size().
// NOTE(review): <word> itself is used as an index without a similar check.
64  std::vector<std::string> wseq_symbols(wseq.size());
65  for (int32 i = 0; i < wseq_symbols.size(); ++i) {
66  KALDI_ASSERT(wseq[i] < label_to_word_.size());
67  wseq_symbols[i] = label_to_word_[wseq[i]];
68  }
69 
70  return rnnlm_.computeConditionalLogprob(label_to_word_[word], wseq_symbols,
71  context_in, context_out);
72 }
73 
// Constructor body: stores the n-gram order limit and the RNNLM wrapper
// pointer (which must be non-NULL), then creates the start state.
// NOTE(review): only the pointer is stored here; this code makes no copy of
// the wrapper object.
76  KALDI_ASSERT(rnnlm != NULL);
77  max_ngram_order_ = max_ngram_order;
78  rnnlm_ = rnnlm;
79 
80  // Uses empty history for <s>.
// State 0 is registered with an empty word sequence and a context vector of
// GetHiddenLayerSize() entries, each initialized to 1.0. The three containers
// state_to_wseq_, state_to_context_ and wseq_to_state_ are kept in step so
// that index 0 refers to this start state in all of them.
81  std::vector<Label> bos;
82  std::vector<float> bos_context(rnnlm->GetHiddenLayerSize(), 1.0);
83  state_to_wseq_.push_back(bos);
84  state_to_context_.push_back(bos_context);
85  wseq_to_state_[bos] = 0;
86  start_state_ = 0;
87 }
88 
// Returns the final weight of state <s>: the negated log-probability that
// rnnlm_ assigns to the end-of-sentence label (rnnlm_->GetEos()) given the
// word history and context stored for <s>.
90  // At this point, we should have created the state.
91  KALDI_ASSERT(static_cast<size_t>(s) < state_to_wseq_.size());
92 
// NULL is passed as the output context: no updated context is requested
// because no successor state is created here.
93  std::vector<Label> wseq = state_to_wseq_[s];
94  BaseFloat logprob = rnnlm_->GetLogProb(rnnlm_->GetEos(), wseq,
95  state_to_context_[s], NULL);
96  return Weight(-logprob);
97 }
98 
// Builds the arc leaving state <s> on label <ilabel>, creating the
// destination state on demand.
//
//   s       source state; must already exist (asserted below).
//   ilabel  word label to follow; used as both input and output label.
//   oarc    filled in with labels, destination state and weight.
//
// Always returns true in the visible code.
100  // At this point, we should have created the state.
101  KALDI_ASSERT(static_cast<size_t>(s) < state_to_wseq_.size());
102 
// Scores <ilabel> given the history and context stored for <s>; the call is
// given &new_context as its output argument.
103  std::vector<Label> wseq = state_to_wseq_[s];
104  std::vector<float> new_context(rnnlm_->GetHiddenLayerSize());
105  BaseFloat logprob = rnnlm_->GetLogProb(ilabel, wseq,
106  state_to_context_[s], &new_context);
107 
// The destination history is the source history plus <ilabel>, trimmed from
// the front. When max_ngram_order_ is not positive, no trimming happens.
108  wseq.push_back(ilabel);
109  if (max_ngram_order_ > 0) {
110  while (wseq.size() >= max_ngram_order_) {
111  // History state has at most <max_ngram_order_> - 1 words in the state.
112  wseq.erase(wseq.begin(), wseq.begin() + 1);
113  }
114  }
115 
// Candidate id for a new state is the next free index in state_to_wseq_.
// NOTE(review): the cast names Label although the pair's second type is
// StateId.
116  std::pair<const std::vector<Label>, StateId> wseq_state_pair(
117  wseq, static_cast<Label>(state_to_wseq_.size()));
118 
119  // Attempts to insert the current <wseq_state_pair>. If the history is already
120  // present, result.second is false and result.first points at the existing entry.
121  typedef MapType::iterator IterType;
122  std::pair<IterType, bool> result = wseq_to_state_.insert(wseq_state_pair);
123 
124  // If the pair was just inserted, then also add it to <state_to_wseq_> and
125  // <state_to_context_>.
// When the history already had a state, new_context is discarded and the
// context stored for that existing state is left unchanged.
126  if (result.second == true) {
127  state_to_wseq_.push_back(wseq);
128  state_to_context_.push_back(new_context);
129  }
130 
131  // Creates the arc.
// The weight is the negated log-probability returned by GetLogProb above.
132  oarc->ilabel = ilabel;
133  oarc->olabel = ilabel;
134  oarc->nextstate = result.first->second;
135  oarc->weight = Weight(-logprob);
136 
137  return true;
138 }
139 
140 } // namespace kaldi
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
fst::StdArc::StateId StateId
Definition: kaldi-rnnlm.h:74
rnnlm::CRnnLM rnnlm_
Definition: kaldi-rnnlm.h:63
void setUnkSym(const std::string &unk)
fst::StdArc StdArc
float logprob
int32 GetHiddenLayerSize() const
Definition: kaldi-rnnlm.h:54
kaldi::int32 int32
BaseFloat GetLogProb(int32 word, const std::vector< int32 > &wseq, const std::vector< float > &context_in, std::vector< float > *context_out)
Definition: kaldi-rnnlm.cc:59
float computeConditionalLogprob(std::string current_word, const std::vector< std::string > &history_words, const std::vector< float > &context_in, std::vector< float > *context_out)
fst::StdArc::Label Label
Definition: kaldi-rnnlm.h:75
std::vector< std::string > label_to_word_
Definition: kaldi-rnnlm.h:64
KaldiRnnlmWrapper(const KaldiRnnlmWrapperOpts &opts, const std::string &unk_prob_rspecifier, const std::string &word_symbol_table_rxfilename, const std::string &rnnlm_rxfilename)
Definition: kaldi-rnnlm.cc:28
#define KALDI_ERR
Definition: kaldi-error.h:147
void setRnnLMFile(const std::string &str)
virtual Weight Final(StateId s)
Definition: kaldi-rnnlm.cc:89
fst::StdArc::Weight Weight
void setRandSeed(int newSeed)
RnnlmDeterministicFst(int32 max_ngram_order, KaldiRnnlmWrapper *rnnlm)
Definition: kaldi-rnnlm.cc:74
void setUnkPenalty(const std::string &filename)
Arc::Weight Weight
Definition: kws-search.cc:31
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
virtual bool GetArc(StateId s, Label ilabel, fst::StdArc *oarc)
Definition: kaldi-rnnlm.cc:99