transition-model.h
Go to the documentation of this file.
1 // hmm/transition-model.h
2 
3 // Copyright 2009-2012 Microsoft Corporation
4 // Johns Hopkins University (author: Guoguo Chen)
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #ifndef KALDI_HMM_TRANSITION_MODEL_H_
22 #define KALDI_HMM_TRANSITION_MODEL_H_
23 
24 #include "base/kaldi-common.h"
25 #include "util/const-integer-set.h"
26 #include "fst/fst-decl.h" // forward declarations.
27 #include "hmm/hmm-topology.h"
28 #include "itf/options-itf.h"
29 #include "itf/context-dep-itf.h"
30 #include "matrix/kaldi-vector.h"
31 
32 namespace kaldi {
33 
36 
37 // The class TransitionModel is a repository for the transition probabilities.
38 // It also handles certain integer mappings.
39 // The basic model is as follows. Each phone has an HMM topology defined in
40 // hmm-topology.h. Each HMM-state of each of these phones has a number of
41 // transitions (and final-probs) out of it. Each HMM-state defined in the
42 // HmmTopology class has an associated "pdf_class". This gets replaced with
43 // an actual pdf-id via the tree. The transition model associates the
44 // transition probs with the (phone, HMM-state, forward-pdf-id,
45 // self-loop-pdf-id) tuple. We associate with each such tuple a transition-state. Each
46 // transition-state has a number of associated probabilities to estimate;
47 // this depends on the number of transitions/final-probs in the topology for
48 // that (phone, HMM-state). Each probability has an associated transition-index.
49 // We associate with each (transition-state, transition-index) a unique transition-id.
50 // Each individual probability estimated by the transition-model is associated with a
51 // transition-id.
52 //
53 // List of the various types of quantity referred to here and what they mean:
54 // phone: a phone index (1, 2, 3 ...)
55 // HMM-state: a number (0, 1, 2...) that indexes TopologyEntry (see hmm-topology.h)
56 // pdf-id: a number output by the Compute function of ContextDependency (it
57 // indexes pdf's, either forward or self-loop). Zero-based.
58 // transition-state: the states for which we estimate transition probabilities for transitions
59 // out of them. In some topologies, will map one-to-one with pdf-ids.
60 // One-based, since it appears on FSTs.
61 // transition-index: identifier of a transition (or final-prob) in the HMM. Indexes the
62 // "transitions" vector in HmmTopology::HmmState. [if it is out of range,
63 // equal to transitions.size(), it refers to the final-prob.]
64 // Zero-based.
65 // transition-id: identifier of a unique parameter of the TransitionModel.
66 // Associated with a (transition-state, transition-index) pair.
67 // One-based, since it appears on FSTs.
68 //
69 // List of the possible mappings TransitionModel can do:
70 // (phone, HMM-state, forward-pdf-id, self-loop-pdf-id) -> transition-state
71 // (transition-state, transition-index) -> transition-id
72 // Reverse mappings:
73 // transition-id -> transition-state
74 // transition-id -> transition-index
75 // transition-state -> phone
76 // transition-state -> HMM-state
77 // transition-state -> forward-pdf-id
78 // transition-state -> self-loop-pdf-id
79 //
80 // The main things the TransitionModel object can do are:
81 // Get initialized (need ContextDependency and HmmTopology objects).
82 // Read/write.
83 // Update [given a vector of counts indexed by transition-id].
84 // Do the various integer mappings mentioned above.
85 // Get the probability (or log-probability) associated with a particular transition-id.
86 
87 
88 // Note: this was previously called TransitionUpdateConfig.
92  bool share_for_pdfs; // If true, share all transition parameters that have the same pdf.
94  BaseFloat mincount = 5.0,
95  bool share_for_pdfs = false):
96  floor(floor), mincount(mincount), share_for_pdfs(share_for_pdfs) {}
97 
// Registers the MLE transition-update options with `opts`, binding each
// option name to the corresponding member of this config:
//   "transition-floor"     -> floor
//   "transition-min-count" -> mincount
//   "share-for-pdfs"       -> share_for_pdfs
// The help strings passed here are what the options object will report.
98  void Register (OptionsItf *opts) {
99  opts->Register("transition-floor", &floor,
100  "Floor for transition probabilities");
101  opts->Register("transition-min-count", &mincount,
102  "Minimum count required to update transitions from a state");
103  opts->Register("share-for-pdfs", &share_for_pdfs,
104  "If true, share all transition parameters where the states "
105  "have the same pdf.");
106  }
107 };
108 
111  bool share_for_pdfs; // If true, share all transition parameters that have the same pdf.
112  MapTransitionUpdateConfig(): tau(5.0), share_for_pdfs(false) { }
113 
// Registers the MAP transition-update options with `opts`, binding each
// option name to the corresponding member of this config:
//   "transition-tau" -> tau
//   "share-for-pdfs" -> share_for_pdfs
// Note that "share-for-pdfs" is the same option name the MLE config registers.
114  void Register (OptionsItf *opts) {
115  opts->Register("transition-tau", &tau, "Tau value for MAP estimation of transition "
116  "probabilities.");
117  opts->Register("share-for-pdfs", &share_for_pdfs,
118  "If true, share all transition parameters where the states "
119  "have the same pdf.");
120  }
121 };
122 
124 
125  public:
130  const HmmTopology &hmm_topo);
131 
132 
// Constructor that takes no arguments: typically used prior to calling Read.
// Only num_pdfs_ is explicitly initialized (to zero) here.
134  TransitionModel(): num_pdfs_(0) { }
135 
136  void Read(std::istream &is, bool binary); // note, no symbol table: topo object always read/written w/o symbols.
137  void Write(std::ostream &os, bool binary) const;
138 
139 
// Returns a const reference to the HMM-topology object held by this model
// (topo_); no copy is made.
141  const HmmTopology &GetTopo() const { return topo_; }
142 
145 
146  int32 TupleToTransitionState(int32 phone, int32 hmm_state, int32 pdf, int32 self_loop_pdf) const;
147  int32 PairToTransitionId(int32 trans_state, int32 trans_index) const;
148  int32 TransitionIdToTransitionState(int32 trans_id) const;
149  int32 TransitionIdToTransitionIndex(int32 trans_id) const;
150  int32 TransitionStateToPhone(int32 trans_state) const;
151  int32 TransitionStateToHmmState(int32 trans_state) const;
152  int32 TransitionStateToForwardPdfClass(int32 trans_state) const;
153  int32 TransitionStateToSelfLoopPdfClass(int32 trans_state) const;
154  int32 TransitionStateToForwardPdf(int32 trans_state) const;
155  int32 TransitionStateToSelfLoopPdf(int32 trans_state) const;
156  int32 SelfLoopOf(int32 trans_state) const; // returns the self-loop transition-id, or zero if
157  // this state doesn't have a self-loop.
158 
159  inline int32 TransitionIdToPdf(int32 trans_id) const;
160  // TransitionIdToPdfFast is as TransitionIdToPdf but skips an assertion
161  // (unless we're in paranoid mode).
162  inline int32 TransitionIdToPdfFast(int32 trans_id) const;
163 
164  int32 TransitionIdToPhone(int32 trans_id) const;
165  int32 TransitionIdToPdfClass(int32 trans_id) const;
166  int32 TransitionIdToHmmState(int32 trans_id) const;
167 
169 
170  bool IsFinal(int32 trans_id) const; // returns true if this trans_id goes to the final state
171  // (which is bound to be nonemitting).
172  bool IsSelfLoop(int32 trans_id) const; // return true if this trans_id corresponds to a self-loop.
173 
// Returns the total number of transition-ids (note, these are one-based).
// id2state_ is indexed by transition-id, so its size is one larger than the
// number of ids; hence the "-1".
175  inline int32 NumTransitionIds() const { return id2state_.size()-1; }
176 
181  int32 NumTransitionIndices(int32 trans_state) const;
182 
// Returns the total number of transition-states (note, these are one-based).
// This is simply the number of entries in tuples_, which is indexed by
// transition-state minus one.
184  int32 NumTransitionStates() const { return tuples_.size(); }
185 
186  // NumPdfs() actually returns the highest-numbered pdf we ever saw, plus one.
187  // In normal cases this should equal the number of pdfs in the system, but if you
188  // initialized this object with fewer than all the phones, and it happens that
189  // an unseen phone has the highest-numbered pdf, this might be different.
190  int32 NumPdfs() const { return num_pdfs_; }
191 
192  // This loops over the tuples and finds the highest phone index present. If
193  // the FST symbol table for the phones is created in the expected way, i.e.:
194  // starting from 1 (<eps> is 0) and numbered contiguously till the last phone,
195  // this will be the total number of phones.
196  int32 NumPhones() const;
197 
// Returns a sorted, unique list of phones. This forwards to the topology
// object's GetPhones() and returns its result by const reference.
199  const std::vector<int32> &GetPhones() const { return topo_.GetPhones(); }
200 
201  // Transition-parameter-getting functions:
202  BaseFloat GetTransitionProb(int32 trans_id) const;
203  BaseFloat GetTransitionLogProb(int32 trans_id) const;
204 
205  // The following functions are more specialized functions for getting
206  // transition probabilities, that are provided for convenience.
207 
213  BaseFloat GetTransitionLogProbIgnoringSelfLoops(int32 trans_id) const;
214 
218  BaseFloat GetNonSelfLoopLogProb(int32 trans_state) const;
219 
222  void MleUpdate(const Vector<double> &stats,
223  const MleTransitionUpdateConfig &cfg,
224  BaseFloat *objf_impr_out,
225  BaseFloat *count_out);
226 
229  void MapUpdate(const Vector<double> &stats,
230  const MapTransitionUpdateConfig &cfg,
231  BaseFloat *objf_impr_out,
232  BaseFloat *count_out);
233 
236  void Print(std::ostream &os,
237  const std::vector<std::string> &phone_names,
238  const Vector<double> *occs = NULL);
239 
240 
// Resizes *stats to NumTransitionIds()+1 elements so that it can be indexed
// directly by (one-based) transition-id, as Accumulate() does. Resize() is
// called with its default resize type.
241  void InitStats(Vector<double> *stats) const { stats->Resize(NumTransitionIds()+1); }
242 
// Adds `prob` to the entry of *stats indexed by `trans_id`.
//   prob:     the amount to add for this transition-id.
//   trans_id: transition-id used directly as the index into *stats.
//   stats:    accumulator vector; presumably sized by InitStats() -- verify
//             against callers.
// NOTE(review): only the upper bound (trans_id <= NumTransitionIds()) is
// asserted here; there is no lower-bound check in this function.
243  void Accumulate(BaseFloat prob, int32 trans_id, Vector<double> *stats) const {
244  KALDI_ASSERT(trans_id <= NumTransitionIds());
245  (*stats)(trans_id) += prob;
246  // This is trivial and doesn't require class members, but leaves us more open
247  // to design changes than doing it manually.
248  }
249 
252  bool Compatible(const TransitionModel &other) const;
253 
254  private:
255  void MleUpdateShared(const Vector<double> &stats,
256  const MleTransitionUpdateConfig &cfg,
257  BaseFloat *objf_impr_out, BaseFloat *count_out);
258  void MapUpdateShared(const Vector<double> &stats,
259  const MapTransitionUpdateConfig &cfg,
260  BaseFloat *objf_impr_out, BaseFloat *count_out);
261  void ComputeTuples(const ContextDependencyInterface &ctx_dep); // called from constructor. initializes tuples_.
262  void ComputeTuplesIsHmm(const ContextDependencyInterface &ctx_dep);
263  void ComputeTuplesNotHmm(const ContextDependencyInterface &ctx_dep);
264  void ComputeDerived(); // called from constructor and Read function: computes state2id_ and id2state_.
265  void ComputeDerivedOfProbs(); // computes quantities derived from log-probs (currently just
266  // non_self_loop_log_probs_); called whenever log-probs change.
267  void InitializeProbs(); // called from constructor.
268  void Check() const;
269  bool IsHmm() const;
270 
// Tuple groups the four integers (phone, hmm_state, forward_pdf,
// self_loop_pdf) that identify a transition-state. It provides a strict
// ordering and an equality test over those four fields.
// NOTE(review): the member declarations themselves (listing lines 272-275)
// are not present in this listing.
271  struct Tuple {
// Default constructor: performs no explicit initialization.
276  Tuple() { }
// Constructs a tuple from its four components.
277  Tuple(int32 phone, int32 hmm_state, int32 forward_pdf, int32 self_loop_pdf):
278  phone(phone), hmm_state(hmm_state), forward_pdf(forward_pdf), self_loop_pdf(self_loop_pdf) { }
// Lexicographic "less than": compares phone first, then hmm_state, then
// forward_pdf, and finally self_loop_pdf; a later field is consulted only
// when all earlier fields are equal.
279  bool operator < (const Tuple &other) const {
280  if (phone < other.phone) return true;
281  else if (phone > other.phone) return false;
282  else if (hmm_state < other.hmm_state) return true;
283  else if (hmm_state > other.hmm_state) return false;
284  else if (forward_pdf < other.forward_pdf) return true;
285  else if (forward_pdf > other.forward_pdf) return false;
286  else return (self_loop_pdf < other.self_loop_pdf);
287  }
// Two tuples are equal only if all four fields match.
288  bool operator == (const Tuple &other) const {
289  return (phone == other.phone && hmm_state == other.hmm_state
290  && forward_pdf == other.forward_pdf && self_loop_pdf == other.self_loop_pdf);
291  }
292  };
293 
295 
299  std::vector<Tuple> tuples_;
300 
304  std::vector<int32> state2id_;
305 
308  std::vector<int32> id2state_;
309 
310  std::vector<int32> id2pdf_id_;
311 
314 
318 
323 
325 };
326 
328  KALDI_ASSERT(
329  static_cast<size_t>(trans_id) < id2pdf_id_.size() &&
330  "Likely graph/model mismatch (graph built from wrong model?)");
331  return id2pdf_id_[trans_id];
332 }
333 
335  // Note: it's a little dangerous to assert this only in paranoid mode.
336  // However, this function is called in the inner loop of decoders and
337  // the assertion likely takes a significant amount of time. We make
338  // sure that past the end of the id2pdf_id_ array there are big
339  // numbers, which will make the calling code more likely to segfault
340  // (rather than silently die) if this is called for out-of-range values.
342  static_cast<size_t>(trans_id) < id2pdf_id_.size() &&
343  "Likely graph/model mismatch (graph built from wrong model?)");
344  return id2pdf_id_[trans_id];
345 }
346 
// Works out which pdfs might correspond to the given phones.
//   trans_model: the transition model to query.
//   phones:      input list of phone indices.
//   pdfs:        output; receives the resulting pdf indices.
// NOTE(review): the meaning of the bool return value is not shown in this
// listing -- check the implementation before relying on it.
356 bool GetPdfsForPhones(const TransitionModel &trans_model,
357  const std::vector<int32> &phones,
358  std::vector<int32> *pdfs);
359 
// Works out which phones might correspond to the given pdfs.
//   trans_model: the transition model to query.
//   pdfs:        input list of pdf indices.
//   phones:      output; receives the resulting phone indices.
// NOTE(review): the meaning of the bool return value is not shown in this
// listing -- check the implementation before relying on it.
362 bool GetPhonesForPdfs(const TransitionModel &trans_model,
363  const std::vector<int32> &pdfs,
364  std::vector<int32> *phones);
366 
367 
368 } // end namespace kaldi
369 
370 
371 #endif
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
std::vector< Tuple > tuples_
Tuples indexed by transition state minus one; the tuples are in sorted order which allows us to do th...
A class for storing topology information for phones.
Definition: hmm-topology.h:93
std::vector< int32 > id2pdf_id_
const std::vector< int32 > & GetPhones() const
Returns a sorted, unique list of phones.
int32 TransitionIdToPdfFast(int32 trans_id) const
int32 num_pdfs_
This is actually one plus the highest-numbered pdf we ever got back from the tree (but the tree numbe...
kaldi::int32 int32
Vector< BaseFloat > log_probs_
For each transition-id, the corresponding log-prob. Indexed by transition-id.
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
#define KALDI_DISALLOW_COPY_AND_ASSIGN(type)
Definition: kaldi-utils.h:121
int32 TransitionIdToPdf(int32 trans_id) const
virtual void Register(const std::string &name, bool *ptr, const std::string &doc)=0
bool GetPhonesForPdfs(const TransitionModel &trans_model, const std::vector< int32 > &pdfs, std::vector< int32 > *phones)
Works out which phones might correspond to the given pdfs.
void InitStats(Vector< double > *stats) const
void Register(OptionsItf *opts)
MleTransitionUpdateConfig(BaseFloat floor=0.01, BaseFloat mincount=5.0, bool share_for_pdfs=false)
int32 NumTransitionIds() const
Returns the total number of transition-ids (note, these are one-based).
void Accumulate(BaseFloat prob, int32 trans_id, Vector< double > *stats) const
std::vector< int32 > id2state_
For each transition-id, the corresponding transition state (indexed by transition-id).
const HmmTopology & GetTopo() const
return reference to HMM-topology object.
TransitionModel()
Constructor that takes no arguments: typically used prior to calling Read.
#define KALDI_PARANOID_ASSERT(cond)
Definition: kaldi-error.h:206
Tuple(int32 phone, int32 hmm_state, int32 forward_pdf, int32 self_loop_pdf)
std::vector< int32 > state2id_
Gives the first transition_id of each transition-state; indexed by the transition-state.
bool operator<(const Int32Pair &a, const Int32Pair &b)
Definition: cu-matrixdim.h:83
context-dep-itf.h provides a link between the tree-building code in ../tree/, and the FST code in ...
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
bool operator==(const LatticeWeightTpl< FloatType > &wa, const LatticeWeightTpl< FloatType > &wb)
void Print(const Fst< Arc > &fst, std::string message)
void Register(OptionsItf *opts)
bool GetPdfsForPhones(const TransitionModel &trans_model, const std::vector< int32 > &phones, std::vector< int32 > *pdfs)
Works out which pdfs might correspond to the given phones.
int32 NumTransitionStates() const
Returns the total number of transition-states (note, these are one-based).
Vector< BaseFloat > non_self_loop_log_probs_
For each transition-state, the log of (1 - self-loop-prob).