RnnlmDeterministicFst Class Reference

#include <kaldi-rnnlm.h>

Inheritance diagram for RnnlmDeterministicFst:
Collaboration diagram for RnnlmDeterministicFst:

Public Types

typedef fst::StdArc::Weight Weight
 
typedef fst::StdArc::StateId StateId
 
typedef fst::StdArc::Label Label
 
- Public Types inherited from DeterministicOnDemandFst< fst::StdArc >
typedef fst::StdArc::StateId StateId
 
typedef fst::StdArc::Weight Weight
 
typedef fst::StdArc::Label Label
 

Public Member Functions

 RnnlmDeterministicFst (int32 max_ngram_order, KaldiRnnlmWrapper *rnnlm)
 
virtual StateId Start ()
 
virtual Weight Final (StateId s)
 
virtual bool GetArc (StateId s, Label ilabel, fst::StdArc *oarc)
 
- Public Member Functions inherited from DeterministicOnDemandFst< fst::StdArc >
virtual Weight Final (StateId s)=0
 
virtual bool GetArc (StateId s, Label ilabel, fst::StdArc *oarc)=0
 Note: ilabel must not be epsilon.
 
virtual ~DeterministicOnDemandFst ()
 

Private Types

typedef unordered_map< std::vector< Label >, StateId, VectorHasher< Label > > MapType
 

Private Attributes

StateId start_state_
 
MapType wseq_to_state_
 
std::vector< std::vector< Label > > state_to_wseq_
 
KaldiRnnlmWrapper * rnnlm_
 
int32 max_ngram_order_
 
std::vector< std::vector< float > > state_to_context_
 

Detailed Description

Definition at line 70 of file kaldi-rnnlm.h.

Member Typedef Documentation

◆ Label

typedef fst::StdArc::Label Label

Definition at line 75 of file kaldi-rnnlm.h.

◆ MapType

typedef unordered_map<std::vector<Label>, StateId, VectorHasher<Label> > MapType
private

Definition at line 92 of file kaldi-rnnlm.h.

◆ StateId

typedef fst::StdArc::StateId StateId

Definition at line 74 of file kaldi-rnnlm.h.

◆ Weight

typedef fst::StdArc::Weight Weight

Definition at line 73 of file kaldi-rnnlm.h.

Constructor & Destructor Documentation

◆ RnnlmDeterministicFst()

RnnlmDeterministicFst ( int32  max_ngram_order,
KaldiRnnlmWrapper * rnnlm 
)

Definition at line 74 of file kaldi-rnnlm.cc.

References KaldiRnnlmWrapper::GetHiddenLayerSize(), KALDI_ASSERT, and KaldiRnnlmWrapper::rnnlm_.

75  {
76  KALDI_ASSERT(rnnlm != NULL);
77  max_ngram_order_ = max_ngram_order;
78  rnnlm_ = rnnlm;
79 
80  // Uses empty history for <s>.
81  std::vector<Label> bos;
82  std::vector<float> bos_context(rnnlm->GetHiddenLayerSize(), 1.0);
83  state_to_wseq_.push_back(bos);
84  state_to_context_.push_back(bos_context);
85  wseq_to_state_[bos] = 0;
86  start_state_ = 0;
87 }
std::vector< std::vector< float > > state_to_context_
Definition: kaldi-rnnlm.h:99
std::vector< std::vector< Label > > state_to_wseq_
Definition: kaldi-rnnlm.h:95
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
KaldiRnnlmWrapper * rnnlm_
Definition: kaldi-rnnlm.h:97

Member Function Documentation

◆ Final()

fst::StdArc::Weight Final ( StateId  s)
virtual

Definition at line 89 of file kaldi-rnnlm.cc.

References KaldiRnnlmWrapper::GetEos(), KaldiRnnlmWrapper::GetLogProb(), KALDI_ASSERT, logprob, and KaldiRnnlmWrapper::rnnlm_.

89  {
90  // At this point, we should have created the state.
91  KALDI_ASSERT(static_cast<size_t>(s) < state_to_wseq_.size());
92 
93  std::vector<Label> wseq = state_to_wseq_[s];
94  BaseFloat logprob = rnnlm_->GetLogProb(rnnlm_->GetEos(), wseq,
95  state_to_context_[s], NULL);
96  return Weight(-logprob);
97 }
float logprob
BaseFloat GetLogProb(int32 word, const std::vector< int32 > &wseq, const std::vector< float > &context_in, std::vector< float > *context_out)
Definition: kaldi-rnnlm.cc:59
std::vector< std::vector< float > > state_to_context_
Definition: kaldi-rnnlm.h:99
std::vector< std::vector< Label > > state_to_wseq_
Definition: kaldi-rnnlm.h:95
float BaseFloat
Definition: kaldi-types.h:29
fst::StdArc::Weight Weight
Definition: kaldi-rnnlm.h:73
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
int32 GetEos() const
Definition: kaldi-rnnlm.h:56
KaldiRnnlmWrapper * rnnlm_
Definition: kaldi-rnnlm.h:97

◆ GetArc()

bool GetArc ( StateId  s,
Label  ilabel,
fst::StdArc * oarc 
)
virtual

Definition at line 99 of file kaldi-rnnlm.cc.

References KaldiRnnlmWrapper::GetHiddenLayerSize(), KaldiRnnlmWrapper::GetLogProb(), KALDI_ASSERT, logprob, and KaldiRnnlmWrapper::rnnlm_.

99  {
100  // At this point, we should have created the state.
101  KALDI_ASSERT(static_cast<size_t>(s) < state_to_wseq_.size());
102 
103  std::vector<Label> wseq = state_to_wseq_[s];
104  std::vector<float> new_context(rnnlm_->GetHiddenLayerSize());
105  BaseFloat logprob = rnnlm_->GetLogProb(ilabel, wseq,
106  state_to_context_[s], &new_context);
107 
108  wseq.push_back(ilabel);
109  if (max_ngram_order_ > 0) {
110  while (wseq.size() >= max_ngram_order_) {
111  // History state has at most <max_ngram_order_> - 1 words in the state.
112  wseq.erase(wseq.begin(), wseq.begin() + 1);
113  }
114  }
115 
116  std::pair<const std::vector<Label>, StateId> wseq_state_pair(
117  wseq, static_cast<Label>(state_to_wseq_.size()));
118 
119  // Attempts to insert the current <wseq_state_pair>. If the pair already exists
120  // then it returns false.
121  typedef MapType::iterator IterType;
122  std::pair<IterType, bool> result = wseq_to_state_.insert(wseq_state_pair);
123 
124  // If the pair was just inserted, then also add it to <state_to_wseq_> and
125  // <state_to_context_>.
126  if (result.second == true) {
127  state_to_wseq_.push_back(wseq);
128  state_to_context_.push_back(new_context);
129  }
130 
131  // Creates the arc.
132  oarc->ilabel = ilabel;
133  oarc->olabel = ilabel;
134  oarc->nextstate = result.first->second;
135  oarc->weight = Weight(-logprob);
136 
137  return true;
138 }
fst::StdArc::StateId StateId
Definition: kaldi-rnnlm.h:74
float logprob
int32 GetHiddenLayerSize() const
Definition: kaldi-rnnlm.h:54
BaseFloat GetLogProb(int32 word, const std::vector< int32 > &wseq, const std::vector< float > &context_in, std::vector< float > *context_out)
Definition: kaldi-rnnlm.cc:59
std::vector< std::vector< float > > state_to_context_
Definition: kaldi-rnnlm.h:99
std::vector< std::vector< Label > > state_to_wseq_
Definition: kaldi-rnnlm.h:95
float BaseFloat
Definition: kaldi-types.h:29
fst::StdArc::Weight Weight
Definition: kaldi-rnnlm.h:73
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
KaldiRnnlmWrapper * rnnlm_
Definition: kaldi-rnnlm.h:97

◆ Start()

virtual StateId Start ( )
inline virtual

Implements DeterministicOnDemandFst< fst::StdArc >.

Definition at line 82 of file kaldi-rnnlm.h.

82 { return start_state_; }

Member Data Documentation

◆ max_ngram_order_

int32 max_ngram_order_
private

Definition at line 98 of file kaldi-rnnlm.h.

◆ rnnlm_

KaldiRnnlmWrapper* rnnlm_
private

Definition at line 97 of file kaldi-rnnlm.h.

◆ start_state_

StateId start_state_
private

Definition at line 93 of file kaldi-rnnlm.h.

◆ state_to_context_

std::vector<std::vector<float> > state_to_context_
private

Definition at line 99 of file kaldi-rnnlm.h.

◆ state_to_wseq_

std::vector<std::vector<Label> > state_to_wseq_
private

Definition at line 95 of file kaldi-rnnlm.h.

◆ wseq_to_state_

MapType wseq_to_state_
private

Definition at line 94 of file kaldi-rnnlm.h.


The documentation for this class was generated from the following files: