transition-model.cc
Go to the documentation of this file.
1 // hmm/transition-model.cc
2 
3 // Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey)
4 // Johns Hopkins University (author: Guoguo Chen)
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include <vector>
22 #include "hmm/transition-model.h"
23 #include "tree/context-dep.h"
24 
25 namespace kaldi {
26 
28  if (IsHmm())
29  ComputeTuplesIsHmm(ctx_dep);
30  else
31  ComputeTuplesNotHmm(ctx_dep);
32 
33  // now tuples_ is populated with all possible tuples of (phone, hmm_state, pdf, self_loop_pdf).
34  std::sort(tuples_.begin(), tuples_.end()); // sort to enable reverse lookup.
35  // this sorting defines the transition-ids.
36 }
37 
39  const std::vector<int32> &phones = topo_.GetPhones();
40  KALDI_ASSERT(!phones.empty());
41 
42  // this is the case for normal models. but not for chain models
43  std::vector<std::vector<std::pair<int32, int32> > > pdf_info;
44  std::vector<int32> num_pdf_classes( 1 + *std::max_element(phones.begin(), phones.end()), -1);
45  for (size_t i = 0; i < phones.size(); i++)
46  num_pdf_classes[phones[i]] = topo_.NumPdfClasses(phones[i]);
47  ctx_dep.GetPdfInfo(phones, num_pdf_classes, &pdf_info);
48  // pdf_info is list indexed by pdf of which (phone, pdf_class) it
49  // can correspond to.
50 
51  std::map<std::pair<int32, int32>, std::vector<int32> > to_hmm_state_list;
52  // to_hmm_state_list is a map from (phone, pdf_class) to the list
53  // of hmm-states in the HMM for that phone that that (phone, pdf-class)
54  // can correspond to.
55  for (size_t i = 0; i < phones.size(); i++) { // setting up to_hmm_state_list.
56  int32 phone = phones[i];
58  for (int32 j = 0; j < static_cast<int32>(entry.size()); j++) { // for each state...
59  int32 pdf_class = entry[j].forward_pdf_class;
60  if (pdf_class != kNoPdf) {
61  to_hmm_state_list[std::make_pair(phone, pdf_class)].push_back(j);
62  }
63  }
64  }
65 
66  for (int32 pdf = 0; pdf < static_cast<int32>(pdf_info.size()); pdf++) {
67  for (size_t j = 0; j < pdf_info[pdf].size(); j++) {
68  int32 phone = pdf_info[pdf][j].first,
69  pdf_class = pdf_info[pdf][j].second;
70  const std::vector<int32> &state_vec = to_hmm_state_list[std::make_pair(phone, pdf_class)];
71  KALDI_ASSERT(!state_vec.empty());
72  // state_vec is a list of the possible HMM-states that emit this
73  // pdf_class.
74  for (size_t k = 0; k < state_vec.size(); k++) {
75  int32 hmm_state = state_vec[k];
76  tuples_.push_back(Tuple(phone, hmm_state, pdf, pdf));
77  }
78  }
79  }
80 }
81 
83  const std::vector<int32> &phones = topo_.GetPhones();
84  KALDI_ASSERT(!phones.empty());
85 
86  // pdf_info is a set of lists indexed by phone. Each list is indexed by
87  // (pdf-class, self-loop pdf-class) of each state of that phone, and the element
88  // is a list of possible (pdf, self-loop pdf) pairs that (pdf-class, self-loop pdf-class)
89  // pair generates.
90  std::vector<std::vector<std::vector<std::pair<int32, int32> > > > pdf_info;
91  // pdf_class_pairs is a set of lists indexed by phone. Each list stores
92  // (pdf-class, self-loop pdf-class) of each state of that phone.
93  std::vector<std::vector<std::pair<int32, int32> > > pdf_class_pairs;
94  pdf_class_pairs.resize(1 + *std::max_element(phones.begin(), phones.end()));
95  for (size_t i = 0; i < phones.size(); i++) {
96  int32 phone = phones[i];
98  for (int32 j = 0; j < static_cast<int32>(entry.size()); j++) { // for each state...
99  int32 forward_pdf_class = entry[j].forward_pdf_class, self_loop_pdf_class = entry[j].self_loop_pdf_class;
100  if (forward_pdf_class != kNoPdf)
101  pdf_class_pairs[phone].push_back(std::make_pair(forward_pdf_class, self_loop_pdf_class));
102  }
103  }
104  ctx_dep.GetPdfInfo(phones, pdf_class_pairs, &pdf_info);
105 
106  std::vector<std::map<std::pair<int32, int32>, std::vector<int32> > > to_hmm_state_list;
107  to_hmm_state_list.resize(1 + *std::max_element(phones.begin(), phones.end()));
108  // to_hmm_state_list is a phone-indexed set of maps from (pdf-class, self-loop pdf_class) to the list
109  // of hmm-states in the HMM for that phone that that (pdf-class, self-loop pdf-class)
110  // can correspond to.
111  for (size_t i = 0; i < phones.size(); i++) { // setting up to_hmm_state_list.
112  int32 phone = phones[i];
113  const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(phone);
114  std::map<std::pair<int32, int32>, std::vector<int32> > phone_to_hmm_state_list;
115  for (int32 j = 0; j < static_cast<int32>(entry.size()); j++) { // for each state...
116  int32 forward_pdf_class = entry[j].forward_pdf_class, self_loop_pdf_class = entry[j].self_loop_pdf_class;
117  if (forward_pdf_class != kNoPdf) {
118  phone_to_hmm_state_list[std::make_pair(forward_pdf_class, self_loop_pdf_class)].push_back(j);
119  }
120  }
121  to_hmm_state_list[phone] = phone_to_hmm_state_list;
122  }
123 
124  for (int32 i = 0; i < phones.size(); i++) {
125  int32 phone = phones[i];
126  for (int32 j = 0; j < static_cast<int32>(pdf_info[phone].size()); j++) {
127  int32 pdf_class = pdf_class_pairs[phone][j].first,
128  self_loop_pdf_class = pdf_class_pairs[phone][j].second;
129  const std::vector<int32> &state_vec =
130  to_hmm_state_list[phone][std::make_pair(pdf_class, self_loop_pdf_class)];
131  KALDI_ASSERT(!state_vec.empty());
132  for (size_t k = 0; k < state_vec.size(); k++) {
133  int32 hmm_state = state_vec[k];
134  for (size_t m = 0; m < pdf_info[phone][j].size(); m++) {
135  int32 pdf = pdf_info[phone][j][m].first,
136  self_loop_pdf = pdf_info[phone][j][m].second;
137  tuples_.push_back(Tuple(phone, hmm_state, pdf, self_loop_pdf));
138  }
139  }
140  }
141  }
142 }
143 
145  state2id_.resize(tuples_.size()+2); // indexed by transition-state, which
146  // is one based, but also an entry for one past end of list.
147 
148  int32 cur_transition_id = 1;
149  num_pdfs_ = 0;
150  for (int32 tstate = 1;
151  tstate <= static_cast<int32>(tuples_.size()+1); // not a typo.
152  tstate++) {
153  state2id_[tstate] = cur_transition_id;
154  if (static_cast<size_t>(tstate) <= tuples_.size()) {
155  int32 phone = tuples_[tstate-1].phone,
156  hmm_state = tuples_[tstate-1].hmm_state,
157  forward_pdf = tuples_[tstate-1].forward_pdf,
158  self_loop_pdf = tuples_[tstate-1].self_loop_pdf;
159  num_pdfs_ = std::max(num_pdfs_, 1 + forward_pdf);
160  num_pdfs_ = std::max(num_pdfs_, 1 + self_loop_pdf);
161  const HmmTopology::HmmState &state = topo_.TopologyForPhone(phone)[hmm_state];
162  int32 my_num_ids = static_cast<int32>(state.transitions.size());
163  cur_transition_id += my_num_ids; // # trans out of this state.
164  }
165  }
166 
167  id2state_.resize(cur_transition_id); // cur_transition_id is #transition-ids+1.
168  id2pdf_id_.resize(cur_transition_id);
169  for (int32 tstate = 1; tstate <= static_cast<int32>(tuples_.size()); tstate++) {
170  for (int32 tid = state2id_[tstate]; tid < state2id_[tstate+1]; tid++) {
171  id2state_[tid] = tstate;
172  if (IsSelfLoop(tid))
173  id2pdf_id_[tid] = tuples_[tstate-1].self_loop_pdf;
174  else
175  id2pdf_id_[tid] = tuples_[tstate-1].forward_pdf;
176  }
177  }
178 
179  // The following statements put copies a large number in the region of memory
180  // past the end of the id2pdf_id_ array, while leaving the array as it was
181  // before. The goal of this is to speed up decoding by disabling a check
182  // inside TransitionIdToPdf() that the transition-id was within the correct
183  // range.
184  int32 num_big_numbers = std::min<int32>(2000, cur_transition_id);
185  id2pdf_id_.resize(cur_transition_id + num_big_numbers,
186  std::numeric_limits<int32>::max());
187  id2pdf_id_.resize(cur_transition_id);
188 }
189 
191  log_probs_.Resize(NumTransitionIds()+1); // one-based array, zeroth element empty.
192  for (int32 trans_id = 1; trans_id <= NumTransitionIds(); trans_id++) {
193  int32 trans_state = id2state_[trans_id];
194  int32 trans_index = trans_id - state2id_[trans_state];
195  const Tuple &tuple = tuples_[trans_state-1];
197  KALDI_ASSERT(static_cast<size_t>(tuple.hmm_state) < entry.size());
198  BaseFloat prob = entry[tuple.hmm_state].transitions[trans_index].second;
199  if (prob <= 0.0)
200  KALDI_ERR << "TransitionModel::InitializeProbs, zero "
201  "probability [should remove that entry in the topology]";
202  if (prob > 1.0)
203  KALDI_WARN << "TransitionModel::InitializeProbs, prob greater than one.";
204  log_probs_(trans_id) = Log(prob);
205  }
207 }
208 
211  {
212  int32 sum = 0;
213  for (int32 ts = 1; ts <= NumTransitionStates(); ts++) sum += NumTransitionIndices(ts);
214  KALDI_ASSERT(sum == NumTransitionIds());
215  }
216  for (int32 tid = 1; tid <= NumTransitionIds(); tid++) {
217  int32 tstate = TransitionIdToTransitionState(tid),
218  index = TransitionIdToTransitionIndex(tid);
219  KALDI_ASSERT(tstate > 0 && tstate <=NumTransitionStates() && index >= 0);
220  KALDI_ASSERT(tid == PairToTransitionId(tstate, index));
221  int32 phone = TransitionStateToPhone(tstate),
222  hmm_state = TransitionStateToHmmState(tstate),
223  forward_pdf = TransitionStateToForwardPdf(tstate),
224  self_loop_pdf = TransitionStateToSelfLoopPdf(tstate);
225  KALDI_ASSERT(tstate == TupleToTransitionState(phone, hmm_state, forward_pdf, self_loop_pdf));
226  KALDI_ASSERT(log_probs_(tid) <= 0.0 && log_probs_(tid) - log_probs_(tid) == 0.0);
227  // checking finite and non-positive (and not out-of-bounds).
228  }
229 }
230 
232  const std::vector<int32> &phones = topo_.GetPhones();
233  KALDI_ASSERT(!phones.empty());
234  for (size_t i = 0; i < phones.size(); i++) {
235  int32 phone = phones[i];
236  const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(phone);
237  for (int32 j = 0; j < static_cast<int32>(entry.size()); j++) { // for each state...
238  if (entry[j].forward_pdf_class != entry[j].self_loop_pdf_class)
239  return false;
240  }
241  }
242  return true;
243 }
244 
246  const HmmTopology &hmm_topo): topo_(hmm_topo) {
247  // First thing is to get all possible tuples.
248  ComputeTuples(ctx_dep);
249  ComputeDerived();
250  InitializeProbs();
251  Check();
252 }
253 
254 int32 TransitionModel::TupleToTransitionState(int32 phone, int32 hmm_state, int32 pdf, int32 self_loop_pdf) const {
255  Tuple tuple(phone, hmm_state, pdf, self_loop_pdf);
256  // Note: if this ever gets too expensive, which is unlikely, we can refactor
257  // this code to sort first on pdf, and then index on pdf, so those
258  // that have the same pdf are in a contiguous range.
259  std::vector<Tuple>::const_iterator iter =
260  std::lower_bound(tuples_.begin(), tuples_.end(), tuple);
261  if (iter == tuples_.end() || !(*iter == tuple)) {
262  KALDI_ERR << "TransitionModel::TupleToTransitionState, tuple not found."
263  << " (incompatible tree and model?)";
264  }
265  // tuples_ is indexed by transition_state-1, so add one.
266  return static_cast<int32>((iter - tuples_.begin())) + 1;
267 }
268 
269 
271  KALDI_ASSERT(static_cast<size_t>(trans_state) <= tuples_.size());
272  return static_cast<int32>(state2id_[trans_state+1]-state2id_[trans_state]);
273 }
274 
276  KALDI_ASSERT(trans_id != 0 && static_cast<size_t>(trans_id) < id2state_.size());
277  return id2state_[trans_id];
278 }
279 
281  KALDI_ASSERT(trans_id != 0 && static_cast<size_t>(trans_id) < id2state_.size());
282  return trans_id - state2id_[id2state_[trans_id]];
283 }
284 
286  KALDI_ASSERT(static_cast<size_t>(trans_state) <= tuples_.size());
287  return tuples_[trans_state-1].phone;
288 }
289 
291  KALDI_ASSERT(static_cast<size_t>(trans_state) <= tuples_.size());
292  return tuples_[trans_state-1].forward_pdf;
293 }
294 
296  int32 trans_state) const {
297  KALDI_ASSERT(static_cast<size_t>(trans_state) <= tuples_.size());
298  const Tuple &t = tuples_[trans_state-1];
299  const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(t.phone);
300  KALDI_ASSERT(static_cast<size_t>(t.hmm_state) < entry.size());
301  return entry[t.hmm_state].forward_pdf_class;
302 }
303 
304 
306  int32 trans_state) const {
307  KALDI_ASSERT(static_cast<size_t>(trans_state) <= tuples_.size());
308  const Tuple &t = tuples_[trans_state-1];
309  const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(t.phone);
310  KALDI_ASSERT(static_cast<size_t>(t.hmm_state) < entry.size());
311  return entry[t.hmm_state].self_loop_pdf_class;
312 }
313 
314 
316  KALDI_ASSERT(static_cast<size_t>(trans_state) <= tuples_.size());
317  return tuples_[trans_state-1].self_loop_pdf;
318 }
319 
321  KALDI_ASSERT(static_cast<size_t>(trans_state) <= tuples_.size());
322  return tuples_[trans_state-1].hmm_state;
323 }
324 
325 int32 TransitionModel::PairToTransitionId(int32 trans_state, int32 trans_index) const {
326  KALDI_ASSERT(static_cast<size_t>(trans_state) <= tuples_.size());
327  KALDI_ASSERT(trans_index < state2id_[trans_state+1] - state2id_[trans_state]);
328  return state2id_[trans_state] + trans_index;
329 }
330 
332  int32 num_trans_state = tuples_.size();
333  int32 max_phone_id = 0;
334  for (int32 i = 0; i < num_trans_state; ++i) {
335  if (tuples_[i].phone > max_phone_id)
336  max_phone_id = tuples_[i].phone;
337  }
338  return max_phone_id;
339 }
340 
341 
342 bool TransitionModel::IsFinal(int32 trans_id) const {
343  KALDI_ASSERT(static_cast<size_t>(trans_id) < id2state_.size());
344  int32 trans_state = id2state_[trans_id];
345  int32 trans_index = trans_id - state2id_[trans_state];
346  const Tuple &tuple = tuples_[trans_state-1];
348  KALDI_ASSERT(static_cast<size_t>(tuple.hmm_state) < entry.size());
349  KALDI_ASSERT(static_cast<size_t>(tuple.hmm_state) < entry.size());
350  KALDI_ASSERT(static_cast<size_t>(trans_index) <
351  entry[tuple.hmm_state].transitions.size());
352  // return true if the transition goes to the final state of the
353  // topology entry.
354  return (entry[tuple.hmm_state].transitions[trans_index].first + 1 ==
355  static_cast<int32>(entry.size()));
356 }
357 
358 
359 
360 int32 TransitionModel::SelfLoopOf(int32 trans_state) const { // returns the self-loop transition-id,
361  KALDI_ASSERT(static_cast<size_t>(trans_state-1) < tuples_.size());
362  const Tuple &tuple = tuples_[trans_state-1];
363  // or zero if does not exist.
364  int32 phone = tuple.phone, hmm_state = tuple.hmm_state;
365  const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(phone);
366  KALDI_ASSERT(static_cast<size_t>(hmm_state) < entry.size());
367  for (int32 trans_index = 0;
368  trans_index < static_cast<int32>(entry[hmm_state].transitions.size());
369  trans_index++)
370  if (entry[hmm_state].transitions[trans_index].first == hmm_state)
371  return PairToTransitionId(trans_state, trans_index);
372  return 0; // invalid transition id.
373 }
374 
376  non_self_loop_log_probs_.Resize(NumTransitionStates()+1); // this array indexed
377  // by transition-state with nothing in zeroth element.
378  for (int32 tstate = 1; tstate <= NumTransitionStates(); tstate++) {
379  int32 tid = SelfLoopOf(tstate);
380  if (tid == 0) { // no self-loop
381  non_self_loop_log_probs_(tstate) = 0.0; // log(1.0)
382  } else {
383  BaseFloat self_loop_prob = Exp(GetTransitionLogProb(tid)),
384  non_self_loop_prob = 1.0 - self_loop_prob;
385  if (non_self_loop_prob <= 0.0) {
386  KALDI_WARN << "ComputeDerivedOfProbs(): non-self-loop prob is " << non_self_loop_prob;
387  non_self_loop_prob = 1.0e-10; // just so we can continue...
388  }
389  non_self_loop_log_probs_(tstate) = Log(non_self_loop_prob); // will be negative.
390  }
391  }
392 }
393 
394 void TransitionModel::Read(std::istream &is, bool binary) {
395  ExpectToken(is, binary, "<TransitionModel>");
396  topo_.Read(is, binary);
397  std::string token;
398  ReadToken(is, binary, &token);
399  int32 size;
400  ReadBasicType(is, binary, &size);
401  tuples_.resize(size);
402  for (int32 i = 0; i < size; i++) {
403  ReadBasicType(is, binary, &(tuples_[i].phone));
404  ReadBasicType(is, binary, &(tuples_[i].hmm_state));
405  ReadBasicType(is, binary, &(tuples_[i].forward_pdf));
406  if (token == "<Tuples>")
407  ReadBasicType(is, binary, &(tuples_[i].self_loop_pdf));
408  else if (token == "<Triples>")
409  tuples_[i].self_loop_pdf = tuples_[i].forward_pdf;
410  }
411  ReadToken(is, binary, &token);
412  KALDI_ASSERT(token == "</Triples>" || token == "</Tuples>");
413  ComputeDerived();
414  ExpectToken(is, binary, "<LogProbs>");
415  log_probs_.Read(is, binary);
416  ExpectToken(is, binary, "</LogProbs>");
417  ExpectToken(is, binary, "</TransitionModel>");
419  Check();
420 }
421 
422 void TransitionModel::Write(std::ostream &os, bool binary) const {
423  bool is_hmm = IsHmm();
424  WriteToken(os, binary, "<TransitionModel>");
425  if (!binary) os << "\n";
426  topo_.Write(os, binary);
427  if (is_hmm)
428  WriteToken(os, binary, "<Triples>");
429  else
430  WriteToken(os, binary, "<Tuples>");
431  WriteBasicType(os, binary, static_cast<int32>(tuples_.size()));
432  if (!binary) os << "\n";
433  for (int32 i = 0; i < static_cast<int32> (tuples_.size()); i++) {
434  WriteBasicType(os, binary, tuples_[i].phone);
435  WriteBasicType(os, binary, tuples_[i].hmm_state);
436  WriteBasicType(os, binary, tuples_[i].forward_pdf);
437  if (!is_hmm)
438  WriteBasicType(os, binary, tuples_[i].self_loop_pdf);
439  if (!binary) os << "\n";
440  }
441  if (is_hmm)
442  WriteToken(os, binary, "</Triples>");
443  else
444  WriteToken(os, binary, "</Tuples>");
445  if (!binary) os << "\n";
446  WriteToken(os, binary, "<LogProbs>");
447  if (!binary) os << "\n";
448  log_probs_.Write(os, binary);
449  WriteToken(os, binary, "</LogProbs>");
450  if (!binary) os << "\n";
451  WriteToken(os, binary, "</TransitionModel>");
452  if (!binary) os << "\n";
453 }
454 
456  return Exp(log_probs_(trans_id));
457 }
458 
460  return log_probs_(trans_id);
461 }
462 
464  KALDI_ASSERT(trans_state != 0);
465  return non_self_loop_log_probs_(trans_state);
466 }
467 
469  KALDI_ASSERT(trans_id != 0);
470  KALDI_PARANOID_ASSERT(!IsSelfLoop(trans_id));
471  return log_probs_(trans_id) - GetNonSelfLoopLogProb(TransitionIdToTransitionState(trans_id));
472 }
473 
474 // stats are counts/weights, indexed by transition-id.
476  const MleTransitionUpdateConfig &cfg,
477  BaseFloat *objf_impr_out,
478  BaseFloat *count_out) {
479  if (cfg.share_for_pdfs) {
480  MleUpdateShared(stats, cfg, objf_impr_out, count_out);
481  return;
482  }
483  BaseFloat count_sum = 0.0, objf_impr_sum = 0.0;
484  int32 num_skipped = 0, num_floored = 0;
485  KALDI_ASSERT(stats.Dim() == NumTransitionIds()+1);
486  for (int32 tstate = 1; tstate <= NumTransitionStates(); tstate++) {
487  int32 n = NumTransitionIndices(tstate);
488  KALDI_ASSERT(n>=1);
489  if (n > 1) { // no point updating if only one transition...
490  Vector<double> counts(n);
491  for (int32 tidx = 0; tidx < n; tidx++) {
492  int32 tid = PairToTransitionId(tstate, tidx);
493  counts(tidx) = stats(tid);
494  }
495  double tstate_tot = counts.Sum();
496  count_sum += tstate_tot;
497  if (tstate_tot < cfg.mincount) { num_skipped++; }
498  else {
499  Vector<BaseFloat> old_probs(n), new_probs(n);
500  for (int32 tidx = 0; tidx < n; tidx++) {
501  int32 tid = PairToTransitionId(tstate, tidx);
502  old_probs(tidx) = new_probs(tidx) = GetTransitionProb(tid);
503  }
504  for (int32 tidx = 0; tidx < n; tidx++)
505  new_probs(tidx) = counts(tidx) / tstate_tot;
506  for (int32 i = 0; i < 3; i++) { // keep flooring+renormalizing for 3 times..
507  new_probs.Scale(1.0 / new_probs.Sum());
508  for (int32 tidx = 0; tidx < n; tidx++)
509  new_probs(tidx) = std::max(new_probs(tidx), cfg.floor);
510  }
511  // Compute objf change
512  for (int32 tidx = 0; tidx < n; tidx++) {
513  if (new_probs(tidx) == cfg.floor) num_floored++;
514  double objf_change = counts(tidx) * (Log(new_probs(tidx))
515  - Log(old_probs(tidx)));
516  objf_impr_sum += objf_change;
517  }
518  // Commit updated values.
519  for (int32 tidx = 0; tidx < n; tidx++) {
520  int32 tid = PairToTransitionId(tstate, tidx);
521  log_probs_(tid) = Log(new_probs(tidx));
522  if (log_probs_(tid) - log_probs_(tid) != 0.0)
523  KALDI_ERR << "Log probs is inf or NaN: error in update or bad stats?";
524  }
525  }
526  }
527  }
528  KALDI_LOG << "TransitionModel::Update, objf change is "
529  << (objf_impr_sum / count_sum) << " per frame over " << count_sum
530  << " frames. ";
531  KALDI_LOG << num_floored << " probabilities floored, " << num_skipped
532  << " out of " << NumTransitionStates() << " transition-states "
533  "skipped due to insuffient data (it is normal to have some skipped.)";
534  if (objf_impr_out) *objf_impr_out = objf_impr_sum;
535  if (count_out) *count_out = count_sum;
537 }
538 
539 
540 // stats are counts/weights, indexed by transition-id.
542  const MapTransitionUpdateConfig &cfg,
543  BaseFloat *objf_impr_out,
544  BaseFloat *count_out) {
545  KALDI_ASSERT(cfg.tau > 0.0);
546  if (cfg.share_for_pdfs) {
547  MapUpdateShared(stats, cfg, objf_impr_out, count_out);
548  return;
549  }
550  BaseFloat count_sum = 0.0, objf_impr_sum = 0.0;
551  KALDI_ASSERT(stats.Dim() == NumTransitionIds()+1);
552  for (int32 tstate = 1; tstate <= NumTransitionStates(); tstate++) {
553  int32 n = NumTransitionIndices(tstate);
554  KALDI_ASSERT(n>=1);
555  if (n > 1) { // no point updating if only one transition...
556  Vector<double> counts(n);
557  for (int32 tidx = 0; tidx < n; tidx++) {
558  int32 tid = PairToTransitionId(tstate, tidx);
559  counts(tidx) = stats(tid);
560  }
561  double tstate_tot = counts.Sum();
562  count_sum += tstate_tot;
563  Vector<BaseFloat> old_probs(n), new_probs(n);
564  for (int32 tidx = 0; tidx < n; tidx++) {
565  int32 tid = PairToTransitionId(tstate, tidx);
566  old_probs(tidx) = new_probs(tidx) = GetTransitionProb(tid);
567  }
568  for (int32 tidx = 0; tidx < n; tidx++)
569  new_probs(tidx) = (counts(tidx) + cfg.tau * old_probs(tidx)) /
570  (cfg.tau + tstate_tot);
571  // Compute objf change
572  for (int32 tidx = 0; tidx < n; tidx++) {
573  double objf_change = counts(tidx) * (Log(new_probs(tidx))
574  - Log(old_probs(tidx)));
575  objf_impr_sum += objf_change;
576  }
577  // Commit updated values.
578  for (int32 tidx = 0; tidx < n; tidx++) {
579  int32 tid = PairToTransitionId(tstate, tidx);
580  log_probs_(tid) = Log(new_probs(tidx));
581  if (log_probs_(tid) - log_probs_(tid) != 0.0)
582  KALDI_ERR << "Log probs is inf or NaN: error in update or bad stats?";
583  }
584  }
585  }
586  KALDI_LOG << "Objf change is " << (objf_impr_sum / count_sum)
587  << " per frame over " << count_sum
588  << " frames.";
589  if (objf_impr_out) *objf_impr_out = objf_impr_sum;
590  if (count_out) *count_out = count_sum;
592 }
593 
594 
595 
600  const MleTransitionUpdateConfig &cfg,
601  BaseFloat *objf_impr_out,
602  BaseFloat *count_out) {
604 
605  BaseFloat count_sum = 0.0, objf_impr_sum = 0.0;
606  int32 num_skipped = 0, num_floored = 0;
607  KALDI_ASSERT(stats.Dim() == NumTransitionIds()+1);
608  std::map<int32, std::set<int32> > pdf_to_tstate;
609 
610  for (int32 tstate = 1; tstate <= NumTransitionStates(); tstate++) {
611  int32 pdf = TransitionStateToForwardPdf(tstate);
612  pdf_to_tstate[pdf].insert(tstate);
613  if (!IsHmm()) {
614  pdf = TransitionStateToSelfLoopPdf(tstate);
615  pdf_to_tstate[pdf].insert(tstate);
616  }
617  }
618  std::map<int32, std::set<int32> >::iterator map_iter;
619  for (map_iter = pdf_to_tstate.begin();
620  map_iter != pdf_to_tstate.end();
621  ++map_iter) {
622  // map_iter->first is pdf-id... not needed.
623  const std::set<int32> &tstates = map_iter->second;
624  KALDI_ASSERT(!tstates.empty());
625  int32 one_tstate = *(tstates.begin());
626  int32 n = NumTransitionIndices(one_tstate);
627  KALDI_ASSERT(n >= 1);
628  if (n > 1) { // Only update if >1 transition...
629  Vector<double> counts(n);
630  for (std::set<int32>::const_iterator iter = tstates.begin();
631  iter != tstates.end();
632  ++iter) {
633  int32 tstate = *iter;
634  if (NumTransitionIndices(tstate) != n)
635  KALDI_ERR << "Mismatch in #transition indices: you cannot "
636  "use the --share-for-pdfs option with this topology "
637  "and sharing scheme.";
638  for (int32 tidx = 0; tidx < n; tidx++) {
639  int32 tid = PairToTransitionId(tstate, tidx);
640  counts(tidx) += stats(tid);
641  }
642  }
643  double pdf_tot = counts.Sum();
644  count_sum += pdf_tot;
645  if (pdf_tot < cfg.mincount) { num_skipped++; }
646  else {
647  // Note: when calculating objf improvement, we
648  // assume we previously had the same tying scheme so
649  // we can get the params from one_tstate and they're valid
650  // for all.
651  Vector<BaseFloat> old_probs(n), new_probs(n);
652  for (int32 tidx = 0; tidx < n; tidx++) {
653  int32 tid = PairToTransitionId(one_tstate, tidx);
654  old_probs(tidx) = new_probs(tidx) = GetTransitionProb(tid);
655  }
656  for (int32 tidx = 0; tidx < n; tidx++)
657  new_probs(tidx) = counts(tidx) / pdf_tot;
658  for (int32 i = 0; i < 3; i++) { // keep flooring+renormalizing for 3 times..
659  new_probs.Scale(1.0 / new_probs.Sum());
660  for (int32 tidx = 0; tidx < n; tidx++)
661  new_probs(tidx) = std::max(new_probs(tidx), cfg.floor);
662  }
663  // Compute objf change
664  for (int32 tidx = 0; tidx < n; tidx++) {
665  if (new_probs(tidx) == cfg.floor) num_floored++;
666  double objf_change = counts(tidx) * (Log(new_probs(tidx))
667  - Log(old_probs(tidx)));
668  objf_impr_sum += objf_change;
669  }
670  // Commit updated values.
671  for (std::set<int32>::const_iterator iter = tstates.begin();
672  iter != tstates.end();
673  ++iter) {
674  int32 tstate = *iter;
675  for (int32 tidx = 0; tidx < n; tidx++) {
676  int32 tid = PairToTransitionId(tstate, tidx);
677  log_probs_(tid) = Log(new_probs(tidx));
678  if (log_probs_(tid) - log_probs_(tid) != 0.0)
679  KALDI_ERR << "Log probs is inf or NaN: error in update or bad stats?";
680  }
681  }
682  }
683  }
684  }
685  KALDI_LOG << "Objf change is " << (objf_impr_sum / count_sum)
686  << " per frame over " << count_sum << " frames; "
687  << num_floored << " probabilities floored, "
688  << num_skipped << " pdf-ids skipped due to insuffient data.";
689  if (objf_impr_out) *objf_impr_out = objf_impr_sum;
690  if (count_out) *count_out = count_sum;
692 }
693 
694 
699  const MapTransitionUpdateConfig &cfg,
700  BaseFloat *objf_impr_out,
701  BaseFloat *count_out) {
703 
704  BaseFloat count_sum = 0.0, objf_impr_sum = 0.0;
705  KALDI_ASSERT(stats.Dim() == NumTransitionIds()+1);
706  std::map<int32, std::set<int32> > pdf_to_tstate;
707 
708  for (int32 tstate = 1; tstate <= NumTransitionStates(); tstate++) {
709  int32 pdf = TransitionStateToForwardPdf(tstate);
710  pdf_to_tstate[pdf].insert(tstate);
711  if (!IsHmm()) {
712  pdf = TransitionStateToSelfLoopPdf(tstate);
713  pdf_to_tstate[pdf].insert(tstate);
714  }
715  }
716  std::map<int32, std::set<int32> >::iterator map_iter;
717  for (map_iter = pdf_to_tstate.begin();
718  map_iter != pdf_to_tstate.end();
719  ++map_iter) {
720  // map_iter->first is pdf-id... not needed.
721  const std::set<int32> &tstates = map_iter->second;
722  KALDI_ASSERT(!tstates.empty());
723  int32 one_tstate = *(tstates.begin());
724  int32 n = NumTransitionIndices(one_tstate);
725  KALDI_ASSERT(n >= 1);
726  if (n > 1) { // Only update if >1 transition...
727  Vector<double> counts(n);
728  for (std::set<int32>::const_iterator iter = tstates.begin();
729  iter != tstates.end();
730  ++iter) {
731  int32 tstate = *iter;
732  if (NumTransitionIndices(tstate) != n)
733  KALDI_ERR << "Mismatch in #transition indices: you cannot "
734  "use the --share-for-pdfs option with this topology "
735  "and sharing scheme.";
736  for (int32 tidx = 0; tidx < n; tidx++) {
737  int32 tid = PairToTransitionId(tstate, tidx);
738  counts(tidx) += stats(tid);
739  }
740  }
741  double pdf_tot = counts.Sum();
742  count_sum += pdf_tot;
743 
744  // Note: when calculating objf improvement, we
745  // assume we previously had the same tying scheme so
746  // we can get the params from one_tstate and they're valid
747  // for all.
748  Vector<BaseFloat> old_probs(n), new_probs(n);
749  for (int32 tidx = 0; tidx < n; tidx++) {
750  int32 tid = PairToTransitionId(one_tstate, tidx);
751  old_probs(tidx) = new_probs(tidx) = GetTransitionProb(tid);
752  }
753  for (int32 tidx = 0; tidx < n; tidx++)
754  new_probs(tidx) = (counts(tidx) + old_probs(tidx) * cfg.tau) /
755  (pdf_tot + cfg.tau);
756  // Compute objf change
757  for (int32 tidx = 0; tidx < n; tidx++) {
758  double objf_change = counts(tidx) * (Log(new_probs(tidx))
759  - Log(old_probs(tidx)));
760  objf_impr_sum += objf_change;
761  }
762  // Commit updated values.
763  for (std::set<int32>::const_iterator iter = tstates.begin();
764  iter != tstates.end();
765  ++iter) {
766  int32 tstate = *iter;
767  for (int32 tidx = 0; tidx < n; tidx++) {
768  int32 tid = PairToTransitionId(tstate, tidx);
769  log_probs_(tid) = Log(new_probs(tidx));
770  if (log_probs_(tid) - log_probs_(tid) != 0.0)
771  KALDI_ERR << "Log probs is inf or NaN: error in update or bad stats?";
772  }
773  }
774  }
775  }
776  KALDI_LOG << "Objf change is " << (objf_impr_sum / count_sum)
777  << " per frame over " << count_sum
778  << " frames.";
779  if (objf_impr_out) *objf_impr_out = objf_impr_sum;
780  if (count_out) *count_out = count_sum;
782 }
783 
784 
786  KALDI_ASSERT(trans_id != 0 && static_cast<size_t>(trans_id) < id2state_.size());
787  int32 trans_state = id2state_[trans_id];
788  return tuples_[trans_state-1].phone;
789 }
790 
792  KALDI_ASSERT(trans_id != 0 && static_cast<size_t>(trans_id) < id2state_.size());
793  int32 trans_state = id2state_[trans_id];
794 
795  const Tuple &t = tuples_[trans_state-1];
797  KALDI_ASSERT(static_cast<size_t>(t.hmm_state) < entry.size());
798  if (IsSelfLoop(trans_id))
799  return entry[t.hmm_state].self_loop_pdf_class;
800  else
801  return entry[t.hmm_state].forward_pdf_class;
802 }
803 
804 
806  KALDI_ASSERT(trans_id != 0 && static_cast<size_t>(trans_id) < id2state_.size());
807  int32 trans_state = id2state_[trans_id];
808  const Tuple &t = tuples_[trans_state-1];
809  return t.hmm_state;
810 }
811 
812 void TransitionModel::Print(std::ostream &os,
813  const std::vector<std::string> &phone_names,
814  const Vector<double> *occs) {
815  if (occs != NULL)
816  KALDI_ASSERT(occs->Dim() == NumPdfs());
817  bool is_hmm = IsHmm();
818  for (int32 tstate = 1; tstate <= NumTransitionStates(); tstate++) {
819  const Tuple &tuple = tuples_[tstate-1];
820  KALDI_ASSERT(static_cast<size_t>(tuple.phone) < phone_names.size());
821  std::string phone_name = phone_names[tuple.phone];
822 
823  os << "Transition-state " << tstate << ": phone = " << phone_name
824  << " hmm-state = " << tuple.hmm_state;
825  if (is_hmm)
826  os << " pdf = " << tuple.forward_pdf << '\n';
827  else
828  os << " forward-pdf = " << tuple.forward_pdf << " self-loop-pdf = "
829  << tuple.self_loop_pdf << '\n';
830  for (int32 tidx = 0; tidx < NumTransitionIndices(tstate); tidx++) {
831  int32 tid = PairToTransitionId(tstate, tidx);
832  BaseFloat p = GetTransitionProb(tid);
833  os << " Transition-id = " << tid << " p = " << p;
834  if (occs != NULL) {
835  if (IsSelfLoop(tid))
836  os << " count of pdf = " << (*occs)(tuple.self_loop_pdf);
837  else
838  os << " count of pdf = " << (*occs)(tuple.forward_pdf);
839  }
840  // now describe what it's a transition to.
841  if (IsSelfLoop(tid)) os << " [self-loop]\n";
842  else {
843  int32 hmm_state = tuple.hmm_state;
845  KALDI_ASSERT(static_cast<size_t>(hmm_state) < entry.size());
846  int32 next_hmm_state = entry[hmm_state].transitions[tidx].first;
847  KALDI_ASSERT(next_hmm_state != hmm_state);
848  os << " [" << hmm_state << " -> " << next_hmm_state << "]\n";
849  }
850  }
851  }
852 }
853 
854 bool GetPdfsForPhones(const TransitionModel &trans_model,
855  const std::vector<int32> &phones,
856  std::vector<int32> *pdfs) {
857  KALDI_ASSERT(IsSortedAndUniq(phones));
858  KALDI_ASSERT(pdfs != NULL);
859  pdfs->clear();
860  for (int32 tstate = 1; tstate <= trans_model.NumTransitionStates(); tstate++) {
861  if (std::binary_search(phones.begin(), phones.end(),
862  trans_model.TransitionStateToPhone(tstate))) {
863  pdfs->push_back(trans_model.TransitionStateToForwardPdf(tstate));
864  pdfs->push_back(trans_model.TransitionStateToSelfLoopPdf(tstate));
865  }
866  }
867  SortAndUniq(pdfs);
868 
869  for (int32 tstate = 1; tstate <= trans_model.NumTransitionStates(); tstate++)
870  if ((std::binary_search(pdfs->begin(), pdfs->end(),
871  trans_model.TransitionStateToForwardPdf(tstate)) ||
872  std::binary_search(pdfs->begin(), pdfs->end(),
873  trans_model.TransitionStateToSelfLoopPdf(tstate)))
874  && !std::binary_search(phones.begin(), phones.end(),
875  trans_model.TransitionStateToPhone(tstate)))
876  return false;
877  return true;
878 }
879 
880 bool GetPhonesForPdfs(const TransitionModel &trans_model,
881  const std::vector<int32> &pdfs,
882  std::vector<int32> *phones) {
884  KALDI_ASSERT(phones != NULL);
885  phones->clear();
886  for (int32 tstate = 1; tstate <= trans_model.NumTransitionStates(); tstate++) {
887  if (std::binary_search(pdfs.begin(), pdfs.end(),
888  trans_model.TransitionStateToForwardPdf(tstate)) ||
889  std::binary_search(pdfs.begin(), pdfs.end(),
890  trans_model.TransitionStateToSelfLoopPdf(tstate)))
891  phones->push_back(trans_model.TransitionStateToPhone(tstate));
892  }
893  SortAndUniq(phones);
894 
895  for (int32 tstate = 1; tstate <= trans_model.NumTransitionStates(); tstate++)
896  if (std::binary_search(phones->begin(), phones->end(),
897  trans_model.TransitionStateToPhone(tstate))
898  && !(std::binary_search(pdfs.begin(), pdfs.end(),
899  trans_model.TransitionStateToForwardPdf(tstate)) &&
900  std::binary_search(pdfs.begin(), pdfs.end(),
901  trans_model.TransitionStateToSelfLoopPdf(tstate))) )
902  return false;
903  return true;
904 }
905 
907  return (topo_ == other.topo_ && tuples_ == other.tuples_ &&
908  state2id_ == other.state2id_ && id2state_ == other.id2state_
909  && num_pdfs_ == other.num_pdfs_);
910 }
911 
912 bool TransitionModel::IsSelfLoop(int32 trans_id) const {
913  KALDI_ASSERT(static_cast<size_t>(trans_id) < id2state_.size());
914  int32 trans_state = id2state_[trans_id];
915  int32 trans_index = trans_id - state2id_[trans_state];
916  const Tuple &tuple = tuples_[trans_state-1];
917  int32 phone = tuple.phone, hmm_state = tuple.hmm_state;
918  const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(phone);
919  KALDI_ASSERT(static_cast<size_t>(hmm_state) < entry.size());
920  return (static_cast<size_t>(trans_index) < entry[hmm_state].transitions.size()
921  && entry[hmm_state].transitions[trans_index].first == hmm_state);
922 }
923 
924 } // End namespace kaldi
virtual void GetPdfInfo(const std::vector< int32 > &phones, const std::vector< int32 > &num_pdf_classes, std::vector< std::vector< std::pair< int32, int32 > > > *pdf_info) const =0
GetPdfInfo returns a vector indexed by pdf-id, saying for each pdf which pairs of (phone, pdf-class) it can correspond to.
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
double Exp(double x)
Definition: kaldi-math.h:83
std::vector< Tuple > tuples_
Tuples indexed by transition state minus one; the tuples are in sorted order which allows us to do th...
A structure defined inside HmmTopology to represent a HMM state.
Definition: hmm-topology.h:96
int32 PairToTransitionId(int32 trans_state, int32 trans_index) const
void MleUpdate(const Vector< double > &stats, const MleTransitionUpdateConfig &cfg, BaseFloat *objf_impr_out, BaseFloat *count_out)
Does Maximum Likelihood estimation.
A class for storing topology information for phones.
Definition: hmm-topology.h:93
std::vector< int32 > id2pdf_id_
void ComputeTuplesNotHmm(const ContextDependencyInterface &ctx_dep)
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
int32 TransitionStateToSelfLoopPdfClass(int32 trans_state) const
int32 TransitionStateToForwardPdfClass(int32 trans_state) const
int32 TransitionStateToForwardPdf(int32 trans_state) const
void ComputeTuplesIsHmm(const ContextDependencyInterface &ctx_dep)
int32 TransitionStateToHmmState(int32 trans_state) const
int32 TransitionIdToPdfClass(int32 trans_id) const
int32 SelfLoopOf(int32 trans_state) const
int32 num_pdfs_
This is actually one plus the highest-numbered pdf we ever got back from the tree (but the tree numbe...
int32 TupleToTransitionState(int32 phone, int32 hmm_state, int32 pdf, int32 self_loop_pdf) const
kaldi::int32 int32
void ReadToken(std::istream &is, bool binary, std::string *str)
ReadToken gets the next token and puts it in str (exception on failure).
Definition: io-funcs.cc:154
void Read(std::istream &is, bool binary)
Definition: hmm-topology.cc:39
Vector< BaseFloat > log_probs_
For each transition-id, the corresponding log-prob. Indexed by transition-id.
std::vector< HmmState > TopologyEntry
TopologyEntry is a typedef that represents the topology of a single (prototype) phone, as a sequence of HMM states.
Definition: hmm-topology.h:133
void SortAndUniq(std::vector< T > *vec)
Sorts a vector and uniq's it (removes duplicate elements).
Definition: stl-utils.h:39
static const int32 kNoPdf
A constant used in the HmmTopology class as the pdf-class kNoPdf, which is used when a HMM-state is n...
Definition: hmm-topology.h:86
void ComputeTuples(const ContextDependencyInterface &ctx_dep)
int32 NumPdfClasses(int32 phone) const
Returns the number of pdf-classes for this phone; throws exception if phone not covered by this topol...
void Print(std::ostream &os, const std::vector< std::string > &phone_names, const Vector< double > *occs=NULL)
Print will print the transition model in a human-readable way, for purposes of human inspection...
void Write(std::ostream &os, bool binary) const
bool GetPhonesForPdfs(const TransitionModel &trans_model, const std::vector< int32 > &pdfs, std::vector< int32 > *phones)
Works out which phones might correspond to the given pdfs.
double Log(double x)
Definition: kaldi-math.h:100
int32 NumTransitionIds() const
Returns the total number of transition-ids (note, these are one-based).
void ExpectToken(std::istream &is, bool binary, const char *token)
ExpectToken tries to read in the given token, and throws an exception on failure. ...
Definition: io-funcs.cc:191
void Read(std::istream &is, bool binary)
int32 TransitionIdToHmmState(int32 trans_id) const
struct rnnlm::@11::@12 n
int32 NumTransitionIndices(int32 trans_state) const
Returns the number of transition-indices for a particular transition-state.
bool IsSelfLoop(int32 trans_id) const
std::vector< int32 > id2state_
For each transition-id, the corresponding transition state (indexed by transition-id).
const TopologyEntry & TopologyForPhone(int32 phone) const
Returns the topology entry (i.e.
BaseFloat GetTransitionLogProb(int32 trans_id) const
TransitionModel()
Constructor that takes no arguments: typically used prior to calling Read.
std::vector< std::pair< int32, BaseFloat > > transitions
A list of transitions, indexed by what we call a 'transition-index'.
Definition: hmm-topology.h:109
#define KALDI_ERR
Definition: kaldi-error.h:147
int32 TransitionIdToTransitionState(int32 trans_id) const
bool Compatible(const TransitionModel &other) const
returns true if all the integer class members are identical (but does not compare the transition prob...
#define KALDI_PARANOID_ASSERT(cond)
Definition: kaldi-error.h:206
#define KALDI_WARN
Definition: kaldi-error.h:150
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
void Scale(Real alpha)
Multiplies all elements by this constant.
std::vector< int32 > state2id_
Gives the first transition_id of each transition-state; indexed by the transition-state.
void MapUpdateShared(const Vector< double > &stats, const MapTransitionUpdateConfig &cfg, BaseFloat *objf_impr_out, BaseFloat *count_out)
This version of the MapUpdate() function is used when the user specifies --share-for-pdfs=true.
int32 TransitionStateToPhone(int32 trans_state) const
Real Sum() const
Returns sum of the elements.
const std::vector< int32 > & GetPhones() const
Returns a reference to a sorted, unique list of phones covered by the topology (these phones will be ...
Definition: hmm-topology.h:163
context-dep-itf.h provides a link between the tree-building code in ../tree/, and the FST code in ...
void Write(std::ostream &os, bool binary) const
BaseFloat GetTransitionLogProbIgnoringSelfLoops(int32 trans_id) const
Returns the log-probability of a particular non-self-loop transition after subtracting the probabilit...
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34
bool GetPdfsForPhones(const TransitionModel &trans_model, const std::vector< int32 > &phones, std::vector< int32 > *pdfs)
Works out which pdfs might correspond to the given phones.
bool IsSortedAndUniq(const std::vector< T > &vec)
Returns true if the vector is sorted and contains each element only once.
Definition: stl-utils.h:63
#define KALDI_LOG
Definition: kaldi-error.h:153
bool IsFinal(int32 trans_id) const
int32 TransitionIdToPhone(int32 trans_id) const
int32 NumTransitionStates() const
Returns the total number of transition-states (note, these are one-based).
BaseFloat GetTransitionProb(int32 trans_id) const
Vector< BaseFloat > non_self_loop_log_probs_
For each transition-state, the log of (1 - self-loop-prob).
void MapUpdate(const Vector< double > &stats, const MapTransitionUpdateConfig &cfg, BaseFloat *objf_impr_out, BaseFloat *count_out)
Does Maximum A Posteriori (MAP) estimation.
BaseFloat GetNonSelfLoopLogProb(int32 trans_state) const
Returns the log-prob of the non-self-loop probability mass for this transition state.
int32 TransitionStateToSelfLoopPdf(int32 trans_state) const
int32 TransitionIdToTransitionIndex(int32 trans_id) const
void MleUpdateShared(const Vector< double > &stats, const MleTransitionUpdateConfig &cfg, BaseFloat *objf_impr_out, BaseFloat *count_out)
This version of the MleUpdate() function is used when the user specifies --share-for-pdfs=true.