30 const std::string &unk_prob_rspecifier,
31 const std::string &word_symbol_table_rxfilename,
32 const std::string &rnnlm_rxfilename) {
40 fst::SymbolTable *word_symbols = NULL;
42 fst::SymbolTable::ReadText(word_symbol_table_rxfilename))) {
43 KALDI_ERR <<
"Could not read symbol table from file " 44 << word_symbol_table_rxfilename;
50 KALDI_ERR <<
"Could not find word for integer " <<
i <<
"in the word " 51 <<
"symbol table, mismatched symbol table or you have discontinuous " 52 <<
"integers in your symbol table?";
60 int32 word,
const std::vector<int32> &wseq,
61 const std::vector<float> &context_in,
62 std::vector<float> *context_out) {
64 std::vector<std::string> wseq_symbols(wseq.size());
65 for (
int32 i = 0;
i < wseq_symbols.size(); ++
i) {
71 context_in, context_out);
77 max_ngram_order_ = max_ngram_order;
81 std::vector<Label> bos;
83 state_to_wseq_.push_back(bos);
84 state_to_context_.push_back(bos_context);
85 wseq_to_state_[bos] = 0;
91 KALDI_ASSERT(static_cast<size_t>(s) < state_to_wseq_.size());
93 std::vector<Label> wseq = state_to_wseq_[s];
95 state_to_context_[s], NULL);
101 KALDI_ASSERT(static_cast<size_t>(s) < state_to_wseq_.size());
103 std::vector<Label> wseq = state_to_wseq_[s];
104 std::vector<float> new_context(
rnnlm_->GetHiddenLayerSize());
106 state_to_context_[s], &new_context);
108 wseq.push_back(ilabel);
109 if (max_ngram_order_ > 0) {
110 while (wseq.size() >= max_ngram_order_) {
112 wseq.erase(wseq.begin(), wseq.begin() + 1);
116 std::pair<const std::vector<Label>,
StateId> wseq_state_pair(
117 wseq, static_cast<Label>(state_to_wseq_.size()));
121 typedef MapType::iterator IterType;
122 std::pair<IterType, bool> result = wseq_to_state_.insert(wseq_state_pair);
126 if (result.second ==
true) {
127 state_to_wseq_.push_back(wseq);
128 state_to_context_.push_back(new_context);
132 oarc->ilabel = ilabel;
133 oarc->olabel = ilabel;
134 oarc->nextstate = result.first->second;
135 oarc->weight =
Weight(-logprob);
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
fst::StdArc::StateId StateId
void setUnkSym(const std::string &unk)
int32 GetHiddenLayerSize() const
BaseFloat GetLogProb(int32 word, const std::vector< int32 > &wseq, const std::vector< float > &context_in, std::vector< float > *context_out)
float computeConditionalLogprob(std::string current_word, const std::vector< std::string > &history_words, const std::vector< float > &context_in, std::vector< float > *context_out)
std::vector< std::string > label_to_word_
KaldiRnnlmWrapper(const KaldiRnnlmWrapperOpts &opts, const std::string &unk_prob_rspecifier, const std::string &word_symbol_table_rxfilename, const std::string &rnnlm_rxfilename)
void setRnnLMFile(const std::string &str)
virtual Weight Final(StateId s)
fst::StdArc::Weight Weight
void setRandSeed(int newSeed)
RnnlmDeterministicFst(int32 max_ngram_order, KaldiRnnlmWrapper *rnnlm)
void setUnkPenalty(const std::string &filename)
#define KALDI_ASSERT(cond)
virtual bool GetArc(StateId s, Label ilabel, fst::StdArc *oarc)