nnet-example-functions.cc
Go to the documentation of this file.
1 // nnet2/nnet-example-functions.cc
2 
3 // Copyright 2012-2013 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
21 #include "lat/lattice-functions.h"
22 
23 namespace kaldi {
24 namespace nnet2 {
25 
26 
28  const std::vector<int32> &alignment,
29  const Matrix<BaseFloat> &feats,
30  const CompactLattice &clat,
31  BaseFloat weight,
32  int32 left_context,
33  int32 right_context,
// Builds a discriminative-training example in *eg from a numerator
// alignment, the features and a denominator lattice.  Returns false (after
// logging a warning) if the alignment is empty, or if the alignment, feature
// and lattice frame counts disagree; returns true on success.
35  KALDI_ASSERT(left_context >= 0 && right_context >= 0);
36  int32 num_frames = alignment.size();
37  if (num_frames == 0) {
38  KALDI_WARN << "Empty alignment";
39  return false;
40  }
41  if (num_frames != feats.NumRows()) {
42  KALDI_WARN << "Dimension mismatch: alignment " << num_frames
43  << " versus feats " << feats.NumRows();
44  return false;
45  }
// The den-lat must span exactly as many frames as the alignment.
46  std::vector<int32> times;
47  int32 num_frames_clat = CompactLatticeStateTimes(clat, &times);
48  if (num_frames_clat != num_frames) {
49  KALDI_WARN << "Numerator/frames versus denlat frames mismatch: "
50  << num_frames << " versus " << num_frames_clat;
51  return false;
52  }
53  eg->weight = weight;
54  eg->num_ali = alignment;
55  eg->den_lat = clat;
56 
57  int32 feat_dim = feats.NumCols();
// input_frames layout: left_context padding rows, then the num_frames real
// feature rows, then right_context padding rows.
58  eg->input_frames.Resize(left_context + num_frames + right_context,
59  feat_dim);
60  eg->input_frames.Range(left_context, num_frames,
61  0, feat_dim).CopyFromMat(feats);
62 
63  // Duplicate the first and last frames.
64  for (int32 t = 0; t < left_context; t++)
65  eg->input_frames.Row(t).CopyFromVec(feats.Row(0));
66  for (int32 t = 0; t < right_context; t++)
67  eg->input_frames.Row(left_context + num_frames + t).CopyFromVec(
68  feats.Row(num_frames - 1));
69 
// Only left_context is stored on the example; the right context is implied
// by input_frames.NumRows() - num_ali.size() - left_context.
// NOTE(review): eg->spk_info is not modified here, so a reused *eg keeps
// whatever speaker info it already held.
70  eg->left_context = left_context;
71  eg->Check();
72  return true;
73 }
74 
75 
76 
77 
78 
79 
101  public:
103  const SplitDiscriminativeExampleConfig &config,
104  const TransitionModel &tmodel,
105  const DiscriminativeNnetExample &eg,
106  std::vector<DiscriminativeNnetExample> *egs_out):
107  config_(config), tmodel_(tmodel), eg_(eg), egs_out_(egs_out) { }
108 
// Writes the excised version of eg_ into *egs_out_: a plain copy of eg_ if
// config_.excise is false, otherwise the output of DoExcise(), which leaves
// *egs_out_ empty if no frame needs to be kept.  stats is only updated on the
// excising path.
109  void Excise(SplitExampleStats *stats) {
110  eg_.Check();
111  PrepareLattice(false);
113  if (!config_.excise) {
114  egs_out_->resize(1);
115  (*egs_out_)[0] = eg_;
116  } else {
117  DoExcise(stats);
118  }
119  }
120 
// Writes the pieces of eg_ into *egs_out_: a plain copy of eg_ if
// config_.split is false, otherwise the kept segments produced by DoSplit().
// stats is only updated on the splitting path.
121  void Split(SplitExampleStats *stats) {
122  if (!config_.split) {
123  egs_out_->resize(1);
124  (*egs_out_)[0] = eg_;
125  } else {
126  eg_.Check();
127  PrepareLattice(true);
129  DoSplit(stats);
130  }
131  }
132 
133  private:
134  typedef LatticeArc Arc;
136  typedef Arc::Label Label;
137 
138  // converts compact lattice to lat_. You should set first_time to true if
139  // this is being called from Split(), but false if being called from Excise()
140  // (this saves some time, since we avoid some preparation steps that we know
141  // are unnecessary because they were done before).
142  void PrepareLattice(bool first_time);
143 
144  void CollapseTransitionIds(); // Modifies the transition-ids on lat_ so that
145  // on each frame, there is just one with any
146  // given pdf-id. This allows us to determinize
147  // and minimize more completely.
148 
// Fills in frame_info_ (one entry per frame, plus one for the end time);
// always returns true.
149  bool ComputeFrameInfo();
150 
// Sets the output label of every arc in *lat to 0 (epsilon).
151  static void RemoveAllOutputSymbols (Lattice *lat);
152 
// Appends to *egs_out_ an example covering frames [seg_begin, seg_end) of eg_.
153  void OutputOneSplit(int32 seg_begin, int32 seg_end);
154 
155  void DoSplit(SplitExampleStats *stats);
156 
157  void DoExcise(SplitExampleStats *stats);
158 
159  int32 NumFrames() const { return static_cast<int32>(eg_.num_ali.size()); }
160 
162 
163 
164  // Put in clat_out a slice of lat_ with first frame at time "seg_begin" and
165  // with last frame at time "seg_end - 1".
166  void CreateOutputLattice(int32 seg_begin, int32 seg_end,
167  CompactLattice *clat_out);
168 
169  // Returns the state-id in this output lattice (creates a
170  // new state if needed).
171  StateId GetOutputStateId(StateId s,
172  unordered_map<StateId, StateId> *state_map,
173  Lattice *lat_out);
174 
175  struct FrameInfo {
177  int32 den_pdf_count; // number of distinct pdfs in denominator lattice
178  bool multiple_transition_ids; // true if there are multiple distinct
179  // transition-ids in the denominator lattice
180  // at this point
181  bool num_den_overlap; // true if num and den overlap.
182 
183  bool nonzero_derivative; // True if we need to keep this frame because the
184  // derivative is nonzero on this frame.
185  bool can_excise_frame; // True if the frame, if part of a segment, can be
186  // excised, *but ignoring the effect of acoustic
187  // context*. I.e. true if the likelihoods and
188  // derivatives from this frame do not matter because
189  // the derivatives are zero and the likelihoods don't
190  // affect lattice posteriors (because pdfs are all
191  // the same on this frame, or if doing mpfe,
192  // transition-ids are all the same).
193 
194  // start_state says, for a segment starting at frame t, what is the
195  // earliest state in lat_ that we have to consider including in the split
196  // lattice? This relates to a kind of optimization for efficiency.
197  StateId start_state;
198 
199  // end_state says, for a segment whose end is time t (i.e. whose final
200  // frame is t-1 and whose final states have time t), what is the latest state
201  // in lat_ that we have to consider including in the split lattice? This
202  // relates to a kind of optimization for efficiency.
203  StateId end_state;
204  FrameInfo(): den_state_count(0), den_pdf_count(0),
205  multiple_transition_ids(false),
206  num_den_overlap(false), nonzero_derivative(false),
207  can_excise_frame(false),
208  start_state(std::numeric_limits<int32>::max()), end_state(0) { }
209  };
210 
211 
212  // The following variables are set in the initializer:
216  std::vector<DiscriminativeNnetExample> *egs_out_;
217 
218  Lattice lat_; // lattice generated from eg_.den_lat, with epsilons removed etc.
219 
220 
221  // The other variables describe lat_ per frame / per state; frame_info_ is
222  // filled in by ComputeFrameInfo(), and both are used by the splitting and
223  // the excising code.
222 
223  std::vector<FrameInfo> frame_info_;
224 
225  // state_times_ says, for each state in lat_, what its start time is.
226  std::vector<int32> state_times_;
227 
228 };
229 
230 // Make sure that for any given pdf-id and any given frame, the den-lat has
231 // only one transition-id mapping to that pdf-id, on the same frame.
232 // It helps us to more completely minimize the lattice. Note: we
233 // can't do this if the criterion is MPFE, because in that case the
234 // objective function will be affected by the phone-identities being
235 // different even if the pdf-ids are the same.
237  std::vector<int32> times;
238  TopSort(&lat_); // Topologically sort the lattice (required by
239  // LatticeStateTimes)
240  int32 num_frames = LatticeStateTimes(lat_, &times);
241  StateId num_states = lat_.NumStates();
242 
243  std::vector<std::map<int32, int32> > pdf_to_tid(num_frames);
244  for (StateId s = 0; s < num_states; s++) {
245  int32 t = times[s];
246  for (fst::MutableArcIterator<Lattice> aiter(&lat_, s);
247  !aiter.Done(); aiter.Next()) {
248  KALDI_ASSERT(t >= 0 && t < num_frames);
249  Arc arc = aiter.Value();
250  KALDI_ASSERT(arc.ilabel != 0 && arc.ilabel == arc.olabel);
251  int32 pdf = tmodel_.TransitionIdToPdf(arc.ilabel);
252  if (pdf_to_tid[t].count(pdf) != 0) {
253  arc.ilabel = arc.olabel = pdf_to_tid[t][pdf];
254  aiter.SetValue(arc);
255  } else {
256  pdf_to_tid[t][pdf] = arc.ilabel;
257  }
258  }
259  }
260 }
261 
262 
265 
// Prepares lat_ for splitting/excising: puts transition-ids on both sides,
// removes epsilons and, if first_time, does further preparation (including
// optional determinization); in all cases lat_ ends up topologically sorted.
266  Project(&lat_, fst::PROJECT_INPUT); // Get rid of the word labels and put the
267  // transition-ids on both sides.
268 
269  RmEpsilon(&lat_); // Remove epsilons; this simplifies
270  // certain things.
271 
272  if (first_time) {
275 
276  if (config_.determinize) {
277  if (!config_.minimize) {
278  Lattice det_lat;
279  Determinize(lat_, &det_lat);
280  lat_ = det_lat;
281  } else {
// Reverse + determinize twice, used instead of the slower Determinize
// followed by Minimize (see the note below).  The final RmEpsilon is
// presumably there to remove epsilon arcs introduced by Reverse().
282  Lattice tmp_lat;
283  Reverse(lat_, &tmp_lat);
284  Determinize(tmp_lat, &lat_);
285  Reverse(lat_, &tmp_lat);
286  Determinize(tmp_lat, &lat_);
287  RmEpsilon(&lat_);
288  // Previously we determinized, then did
289  // Minimize(&lat_);
290  // but this was too slow.
291  }
292  }
293  }
294  TopSort(&lat_); // Topologically sort the lattice.
295 }
296 
297 // this function computes various arrays that say something about
298 // this frame of the lattice.
300 
301  int32 num_frames = NumFrames();
302 
303  frame_info_.clear();
// One entry per frame plus one extra, so that states at time num_frames
// (the end of the lattice) can also be indexed.
304  frame_info_.resize(num_frames + 1);
305 
307 
// The distinct pdf-ids and transition-ids on den-lat arcs leaving states at
// each frame.
308  std::vector<std::set<int32> > pdfs_per_frame(num_frames),
309  tids_per_frame(num_frames);
310 
311  int32 num_states = lat_.NumStates();
312 
313  for (int32 state = 0; state < num_states; state++) {
314  int32 t = state_times_[state];
315  KALDI_ASSERT(t >= 0 && t <= num_frames);
316  frame_info_[t].den_state_count++;
317  for (fst::ArcIterator<Lattice> aiter(lat_, state); !aiter.Done();
318  aiter.Next()) {
319  const LatticeArc &arc = aiter.Value();
320  KALDI_ASSERT(arc.ilabel != 0 && arc.ilabel == arc.olabel);
321  int32 transition_id = arc.ilabel,
322  pdf_id = tmodel_.TransitionIdToPdf(transition_id);
323  tids_per_frame[t].insert(transition_id);
324  pdfs_per_frame[t].insert(pdf_id);
325  }
// Record per-frame state-id bounds, which CreateOutputLattice() uses to limit
// its scan: start_state of frame t+1 is lowered to the smallest state-id whose
// time is t, and end_state of frame t is updated with std::max(state, ...).
326  if (t < num_frames)
327  frame_info_[t+1].start_state = std::min(state,
328  frame_info_[t+1].start_state);
329  frame_info_[t].end_state = std::max(state,
331  }
332 
// Propagate end_state forward in time (presumably a running maximum) and
// start_state backward in time; NOTE(review): part of this code is not
// visible in this view.
333  for (int32 i = 1; i <= NumFrames(); i++)
334  frame_info_[i].end_state = std::max(frame_info_[i-1].end_state,
336  for (int32 i = NumFrames() - 1; i >= 0; i--)
339 
340  for (int32 t = 0; t < num_frames; t++) {
341  FrameInfo &frame_info = frame_info_[t];
342  int32 transition_id = eg_.num_ali[t],
343  pdf_id = tmodel_.TransitionIdToPdf(transition_id);
344  frame_info.num_den_overlap = (pdfs_per_frame[t].count(pdf_id) != 0);
345  frame_info.multiple_transition_ids = (tids_per_frame[t].size() > 1);
346  KALDI_ASSERT(!pdfs_per_frame[t].empty());
347  frame_info.den_pdf_count = pdfs_per_frame[t].size();
348 
349  if (config_.criterion == "mpfe" || config_.criterion == "smbr") {
350  frame_info.nonzero_derivative = (frame_info.den_pdf_count > 1);
351  } else {
352  KALDI_ASSERT(config_.criterion == "mmi");
353  if (config_.drop_frames) {
354  // With frame dropping, we'll get nonzero derivative only
355  // if num and den overlap, *and* den has >1 active pdf.
356  frame_info.nonzero_derivative = frame_info.num_den_overlap &&
357  frame_info.den_state_count > 1;
358  } else {
359  // Without frame dropping, we'll get nonzero derivative if num and den
360  // do not overlap , or den has >1 active pdf.
361  frame_info.nonzero_derivative = !frame_info.num_den_overlap ||
362  frame_info.den_state_count > 1;
363  }
364  }
365  // If a frame is part of a segment, but it's not going to contribute
366  // to the derivative and the den lattice has only one pdf active
367  // at that time, then this frame can be excised from the lattice
368  // because it will not affect the posteriors around it.
369  if (config_.criterion == "mpfe") {
370  frame_info.can_excise_frame =
371  !frame_info.nonzero_derivative && \
372  !frame_info.multiple_transition_ids;
373  // in the mpfe case, if there are multiple transition-ids on a
374  // frame there may be multiple phones on a frame, which could
375  // contribute to the objective function even if they share pdf-ids.
376  // (this was an issue that came up during testing).
377  } else {
378  frame_info.can_excise_frame =
379  !frame_info.nonzero_derivative && frame_info.den_pdf_count == 1;
380  }
381  }
382  return true;
383 }
384 
385 
386 /* Excising a frame means removing a frame from the lattice and removing the
387  corresponding feature. We can only do this if it would not affect the
388  derivatives because the current frame has zero derivative and also all the
389  den-lat pdfs are the same on this frame (so removing the frame doesn't affect
390  the lattice posteriors). But we can't remove a frame if doing so would
391  affect the acoustic context. Generally speaking we must keep all frames
392  that are within LeftContext() to the left and RightContext() to the right
393  of a frame that we can't excise, *but* it's OK at the edges of a segment
394  even if they are that close to other frames, because we anyway keep a few
395  frames of context at the edges, and we can just make sure to keep the
396  *right* few frames of context.
397  */
// Removes the frames of eg_ that affect neither the derivatives nor (through
// acoustic context) any frame that does.  The single resulting example goes
// in (*egs_out_)[0]; *egs_out_ is left empty if no frame is needed.
// Note: this modifies lat_.
399  int32 left_context = eg_.left_context,
400  right_context = RightContext(),
401  num_frames = NumFrames();
402  // Compute, for each frame, whether we can excise it.
403  //
404  std::vector<bool> can_excise(num_frames, false);
405 
406  bool need_some_frame = false;
407  for (int32 t = 0; t < num_frames; t++) {
408  can_excise[t] = frame_info_[t].can_excise_frame;
409  if (!can_excise[t])
410  need_some_frame = true;
411  }
412  if (!need_some_frame) { // We don't need any frame within this file, so simply
413  // delete the segment.
414  KALDI_WARN << "Example completely removed when excising."; // unexpected,
415  // as the segment should have been deleted when splitting.
416  egs_out_->clear();
417  return;
418  }
419  egs_out_->resize(1);
420  DiscriminativeNnetExample &eg_out = (*egs_out_)[0];
421 
422  // start_t and end_t will be the central part of the segment, excluding any
423  // frames at the edges that we can excise.
// need_some_frame guarantees that some frame has can_excise == false, so both
// loops below stop inside [0, num_frames).
424  int32 start_t, end_t;
425  for (start_t = 0; can_excise[start_t]; start_t++);
426  for (end_t = num_frames; can_excise[end_t-1]; end_t--);
427 
428  // for frames from start_t to end_t-1, do not excise them if
429  // they are within the context-window of a frame that we need to keep.
430  // Note: we do t2 = t - right_context to t + left_context, because we're
431  // concerned whether frame t2 has frame t in its window... it might
432  // seem a bit backwards.
433  std::vector<bool> will_excise(can_excise);
434  for (int32 t = start_t; t < end_t; t++) {
435  for (int32 t2 = t - right_context; t2 <= t + left_context; t2++)
436  if (t2 >= start_t && t2 < end_t && !can_excise[t2])
437  will_excise[t] = false; // can't excise this frame, it's needed for
438  // context.
439  }
440 
441  // Remove all un-needed frames from the lattice by replacing the
442  // symbols with epsilon and then removing the epsilons.
443  // Note, this operation is destructive (it changes lat_).
// will_excise[t] is only read for states that have outgoing arcs; such states
// are expected to have times < num_frames.
444  int32 num_states = lat_.NumStates();
445  for (int32 state = 0; state < num_states; state++) {
446  int32 t = state_times_[state];
447  for (::fst::MutableArcIterator<Lattice> aiter(&lat_, state); !aiter.Done();
448  aiter.Next()) {
449  Arc arc = aiter.Value();
450  if (will_excise[t]) {
451  arc.ilabel = arc.olabel = 0;
452  aiter.SetValue(arc);
453  }
454  }
455  }
456  RmEpsilon(&lat_);
458  ConvertLattice(lat_, &eg_out.den_lat);
459 
460  eg_out.num_ali.clear();
461  int32 num_frames_kept = 0;
462  for (int32 t = 0; t < num_frames; t++) {
463  if (!will_excise[t]) {
464  eg_out.num_ali.push_back(eg_.num_ali[t]);
465  num_frames_kept++;
466  }
467  }
468 
469  stats->num_frames_kept_after_excise += num_frames_kept;
471  num_frames_kept);
472 
// Output feature rows: left_context context rows, then the num_frames_kept
// kept frames, then right_context context rows.
473  int32 num_frames_kept_plus = num_frames_kept + left_context + right_context;
474  eg_out.input_frames.Resize(num_frames_kept_plus,
476 
477  // the left-context of the output will be shifted to the right by
478  // start_t.
479  for (int32 i = 0; i < left_context; i++) {
480  SubVector<BaseFloat> dst(eg_out.input_frames, i);
481  SubVector<BaseFloat> src(eg_.input_frames, start_t + i);
482  dst.CopyFromVec(src);
483  }
484  // the right-context will also be shifted, we take the frames
485  // to the right of end_t.
486  for (int32 i = 0; i < right_context; i++) {
488  num_frames_kept + left_context + i);
490  end_t + left_context + i);
491  dst.CopyFromVec(src);
492  }
493  // now copy the central frames (those that were not excised).
494  int32 dst_t = 0;
495  for (int32 t = start_t; t < end_t; t++) {
496  if (!will_excise[t]) {
498  left_context + dst_t);
500  left_context + t);
501  dst.CopyFromVec(src);
502  dst_t++;
503  }
504  }
505  KALDI_ASSERT(dst_t == num_frames_kept);
506 
507 
508  eg_out.weight = eg_.weight;
509  eg_out.left_context = eg_.left_context;
510  eg_out.spk_info = eg_.spk_info;
511 
512  eg_out.Check();
513 }
514 
515 
// Cuts eg_ at the den-lat's "pinch points" and outputs (via OutputOneSplit)
// only the segments that contain at least one frame with nonzero derivative;
// *stats is updated with counts for all candidate segments.
517  std::vector<int32> split_points;
518  int32 num_frames = NumFrames();
519  {
520  // Make the "split points" 0 and num_frames, and
521  // any frame that has just one state on it and the previous
522  // frame had >1 state. This gives us one split for each
523  // "pinch point" in the lattice. Later we may move each split
524  // to a more optimal location.
// NOTE(review): no such relocation of split points is done in this function.
525  split_points.push_back(0);
526  for (int32 t = 1; t < num_frames; t++) {
527  if (frame_info_[t].den_state_count == 1 &&
528  frame_info_[t-1].den_state_count > 1)
529  split_points.push_back(t);
530  }
531  split_points.push_back(num_frames);
532  }
533 
534  std::vector<bool> is_kept(split_points.size() - 1);
535  { // A "split" is a pair of successive split points. Work out for each split
536  // whether we must keep it (we must if it contains at least one frame for
537  // which "nonzero_derivative" == true.)
538  for (size_t s = 0; s < is_kept.size(); s++) {
539  int32 start = split_points[s], end = split_points[s+1];
540  bool keep_this_split = false;
541  for (int32 t = start; t < end; t++)
543  keep_this_split = true;
544  is_kept[s] = keep_this_split;
545  }
546  }
547 
548  egs_out_->clear();
549  egs_out_->reserve(is_kept.size());
550 
// Statistics: num_segments counts every candidate segment, while
// num_kept_segments (below) counts only those actually output.
551  stats->num_lattices++;
552  stats->longest_lattice = std::max(stats->longest_lattice, num_frames);
553  stats->num_segments += is_kept.size();
554  stats->num_frames_orig += num_frames;
555  for (int32 t = 0; t < num_frames; t++)
557  stats->num_frames_must_keep++;
558 
559  for (size_t s = 0; s < is_kept.size(); s++) {
560  if (is_kept[s]) {
561  stats->num_kept_segments++;
562  OutputOneSplit(split_points[s], split_points[s+1]);
563  int32 segment_len = split_points[s+1] - split_points[s];
564  stats->num_frames_kept_after_split += segment_len;
566  std::max(stats->longest_segment_after_split, segment_len);
567  }
568  }
569 }
570 
571 
572 
574  KALDI_LOG << "Split " << num_lattices << " lattices. Stats:";
575  double kept_segs_per_lat = num_kept_segments * 1.0 / num_lattices,
576  segs_per_lat = num_segments * 1.0 / num_lattices;
577 
578  KALDI_LOG << "Made on average " << segs_per_lat << " segments per lattice, "
579  << "of which " << kept_segs_per_lat << " were kept.";
580 
581  double percent_needed = num_frames_must_keep * 100.0 / num_frames_orig,
582  percent_after_split = num_frames_kept_after_split * 100.0 / num_frames_orig,
583  percent_after_excise = num_frames_kept_after_excise * 100.0 / num_frames_orig;
584 
585  KALDI_LOG << "Needed to keep " << percent_needed << "% of frames, after split "
586  << "kept " << percent_after_split << "%, after excising frames kept "
587  << percent_after_excise << "%.";
588 
589  KALDI_LOG << "Longest lattice had " << longest_lattice
590  << " frames, longest segment after splitting had "
591  << longest_segment_after_split
592  << " frames, longest segment after excising had "
593  << longest_segment_after_excise;
594 }
595 
597  int32 seg_end) {
598  KALDI_ASSERT(seg_begin >= 0 && seg_end > seg_begin && seg_end <= NumFrames());
599  egs_out_->resize(egs_out_->size() + 1);
600  int32 left_context = eg_.left_context, right_context = RightContext(),
601  tot_context = left_context + right_context;
602  DiscriminativeNnetExample &eg_out = egs_out_->back();
603  eg_out.weight = eg_.weight;
604 
605  eg_out.num_ali.insert(eg_out.num_ali.end(),
606  eg_.num_ali.begin() + seg_begin,
607  eg_.num_ali.begin() + seg_end);
608 
609  CreateOutputLattice(seg_begin, seg_end, &(eg_out.den_lat));
610 
611  eg_out.input_frames = eg_.input_frames.Range(seg_begin, seg_end - seg_begin +
612  tot_context,
613  0, eg_.input_frames.NumCols());
614 
615  eg_out.left_context = eg_.left_context;
616 
617  eg_out.spk_info = eg_.spk_info;
618 
619  eg_out.Check();
620 }
621 
622 // static
624  for (StateId s = 0; s < lat->NumStates(); s++) {
625  for (::fst::MutableArcIterator<Lattice> aiter(lat, s); !aiter.Done();
626  aiter.Next()) {
627  Arc arc = aiter.Value();
628  arc.olabel = 0;
629  aiter.SetValue(arc);
630  }
631  }
632 }
633 
636  StateId s, unordered_map<StateId, StateId> *state_map, Lattice *lat_out) {
637  if (state_map->count(s) == 0) {
638  return ((*state_map)[s] = lat_out->AddState());
639  } else {
640  return (*state_map)[s];
641  }
642 }
643 
645  int32 seg_begin, int32 seg_end,
646  CompactLattice *clat_out) {
647  Lattice lat_out;
648 
649  // Below, state_map will map from states in the original lattice
650  // lat_ to ones in the new lattice lat_out.
651  unordered_map<StateId, StateId> state_map;
652 
653  // The range of the loop over s could be made over the
654  // entire lattice, but we limit it for efficiency.
655 
656  for (StateId s = frame_info_[seg_begin].start_state;
657  s <= frame_info_[seg_end].end_state; s++) {
658  int32 t = state_times_[s];
659 
660  if (t < seg_begin || t > seg_end) // state out of range.
661  continue;
662 
663  int32 this_state = GetOutputStateId(s, &state_map, &lat_out);
664 
665  if (t == seg_begin) // note: we only split on frames with just one
666  lat_out.SetStart(this_state); // state, so we reach this only once.
667 
668  if (t == seg_end) { // Make it final and don't process its arcs out.
669  if (seg_end == NumFrames()) {
670  lat_out.SetFinal(this_state, lat_.Final(s));
671  } else {
672  lat_out.SetFinal(this_state, LatticeWeight::One());
673  }
674  continue; // don't process arcs out of this state.
675  }
676 
677  for (fst::ArcIterator<Lattice> aiter(lat_, s); !aiter.Done(); aiter.Next()) {
678  const Arc &arc = aiter.Value();
679  StateId next_state = GetOutputStateId(arc.nextstate,
680  &state_map, &lat_out);
681  KALDI_ASSERT(arc.ilabel != 0 && arc.ilabel == arc.olabel); // We expect no epsilons.
682  lat_out.AddArc(this_state, Arc(arc.ilabel, arc.olabel, arc.weight,
683  next_state));
684  }
685  }
686  Connect(&lat_out); // this is not really necessary, it's only to make sure
687  // the assert below fails when it should. TODO: remove it.
688  KALDI_ASSERT(lat_out.NumStates() > 0);
689  RemoveAllOutputSymbols(&lat_out);
690  ConvertLattice(lat_out, clat_out);
691 }
692 
693 /*
694 void DiscriminativeExampleSplitter::SelfTest() {
695  bool splits_ok = true; // True iff we split only
696  // on frames where there was
697  // one arc crossing.
698 
699  // we can't do any of this excising frames if we want to
700  // preserve equivalence.
701  std::fill(can_excise_.begin(), can_excise_.end(), false);
702 
703  std::vector<Lattice*> split_lats;
704 
705  int32 cur_t = NumFrames();
706  while (cur_t != 0) {
707  Backtrace this_backtrace = backtrace_[cur_t];
708  int32 prev_t = this_backtrace.prev_frame;
709 
710  int32 seg_begin = prev_t, seg_end = cur_t;
711  Lattice *new_lat = new Lattice();
712  CreateOutputLattice(seg_begin, seg_end, new_lat);
713  split_lats.push_back(new_lat);
714 
715  if (split_penalty_[cur_t] != 0)
716  splits_ok = false; // we split where there was a penalty so we don't
717  // expect equivalence.
718  cur_t = prev_t;
719  }
720  KALDI_ASSERT(!split_lats.empty());
721  std::reverse(split_lats.begin(), split_lats.end());
722  for (size_t i = 1; i < split_lats.size(); i++) {
723  // append split_lats[i] to split_lats[0], putting the
724  // result in split_lats[0].
725  Concat(split_lats[0], *(split_lats[i]));
726  }
727  Connect(split_lats[0]);
728  KALDI_ASSERT(split_lats[0]->NumStates() > 0);
729 
730 
731  if (!splits_ok) {
732  KALDI_LOG << "Not self-testing because we split where there were "
733  << "multiple paths.";
734 
735  } else {
736  if (!(RandEquivalent(*(split_lats[0]), lat_, 5, 0.01,
737  Rand(), 100))) {
738  KALDI_WARN << "Lattices were not equivalent (self-test failed).";
739  KALDI_LOG << "Original lattice was: ";
740  WriteLattice(std::cerr, false, lat_);
741  KALDI_LOG << "New lattice is:";
742  WriteLattice(std::cerr, false, *(split_lats[0]));
743  {
744  Lattice best_path_orig;
745  ShortestPath(lat_, &best_path_orig);
746  KALDI_LOG << "Original best path was:";
747  WriteLattice(std::cerr, false, best_path_orig);
748  }
749  {
750  Lattice best_path_new;
751  ShortestPath(*(split_lats[0]), &best_path_new);
752  KALDI_LOG << "New best path was:";
753  WriteLattice(std::cerr, false, best_path_new);
754  }
755  }
756  }
757  for (size_t i = 0; i < split_lats.size(); i++)
758  delete split_lats[i];
759 }
760 */
761 
762 
763 
765  const SplitDiscriminativeExampleConfig &config,
766  const TransitionModel &tmodel,
767  const DiscriminativeNnetExample &eg,
768  std::vector<DiscriminativeNnetExample> *egs_out,
769  SplitExampleStats *stats_out) {
770  DiscriminativeExampleSplitter splitter(config, tmodel, eg, egs_out);
771  splitter.Split(stats_out);
772 }
773 
774 
776  const SplitDiscriminativeExampleConfig &config,
777  const TransitionModel &tmodel,
778  const DiscriminativeNnetExample &eg,
779  std::vector<DiscriminativeNnetExample> *egs_out,
780  SplitExampleStats *stats_out) {
781  DiscriminativeExampleSplitter splitter(config, tmodel, eg, egs_out);
782  splitter.Excise(stats_out);
783 }
784 
785 
787  const TransitionModel &tmodel,
788  const DiscriminativeNnetExample &eg,
789  std::string criterion,
790  bool drop_frames,
791  bool one_silence_class,
792  Matrix<double> *hash,
793  double *num_weight,
794  double *den_weight,
795  double *tot_t) {
796  int32 feat_dim = eg.input_frames.NumCols(),
797  left_context = eg.left_context,
798  num_frames = eg.num_ali.size(),
799  right_context = eg.input_frames.NumRows() - num_frames - left_context,
800  context_width = left_context + 1 + right_context;
801  *tot_t += num_frames;
802  KALDI_ASSERT(right_context >= 0);
803  KALDI_ASSERT(hash != NULL);
804  if (hash->NumRows() == 0) {
805  hash->Resize(tmodel.NumPdfs(), feat_dim);
806  } else {
807  KALDI_ASSERT(hash->NumRows() == tmodel.NumPdfs() &&
808  hash->NumCols() == feat_dim);
809  }
810 
811  Posterior post;
812  std::vector<int32> silence_phones; // we don't let the user specify this
813  // because it's not necessary for testing
814  // purposes -> leave it empty
815  ExampleToPdfPost(tmodel, silence_phones, criterion, drop_frames,
816  one_silence_class, eg, &post);
817 
818  Vector<BaseFloat> avg_feat(feat_dim);
819 
820  for (int32 t = 0; t < num_frames; t++) {
821  SubMatrix<BaseFloat> context_window(eg.input_frames,
822  t, context_width,
823  0, feat_dim);
824  // set avg_feat to average over the context-window for this frame.
825  avg_feat.AddRowSumMat(1.0 / context_width, context_window, 0.0);
826  Vector<double> avg_feat_dbl(avg_feat);
827  for (size_t i = 0; i < post[t].size(); i++) {
828  int32 pdf_id = post[t][i].first;
829  BaseFloat weight = post[t][i].second;
830  hash->Row(pdf_id).AddVec(weight, avg_feat_dbl);
831  if (weight > 0.0) *num_weight += weight;
832  else *den_weight += -weight;
833  }
834  }
835 }
836 
837 
839  const TransitionModel &tmodel,
840  const std::vector<int32> &silence_phones,
841  std::string criterion,
842  bool drop_frames,
843  bool one_silence_class,
844  const DiscriminativeNnetExample &eg,
845  Posterior *post) {
846  KALDI_ASSERT(criterion == "mpfe" || criterion == "smbr" || criterion == "mmi");
847 
848  Lattice lat;
849  ConvertLattice(eg.den_lat, &lat);
850  TopSort(&lat);
851  if (criterion == "mpfe" || criterion == "smbr") {
852  Posterior tid_post;
853  LatticeForwardBackwardMpeVariants(tmodel, silence_phones, lat, eg.num_ali,
854  criterion, one_silence_class, &tid_post);
855 
856  ConvertPosteriorToPdfs(tmodel, tid_post, post);
857  } else {
858  bool convert_to_pdf_ids = true, cancel = true;
859  LatticeForwardBackwardMmi(tmodel, lat, eg.num_ali,
860  drop_frames, convert_to_pdf_ids, cancel,
861  post);
862  }
863  ScalePosterior(eg.weight, post);
864 }
865 
866 
868  const std::vector<BaseFloat> &costs,
869  std::vector<std::vector<size_t> > *groups) {
870  groups->clear();
871  std::vector<BaseFloat> group_costs;
872  for (size_t i = 0; i < costs.size(); i++) {
873  bool found_group = false;
874  BaseFloat this_cost = costs[i];
875  for (size_t j = 0; j < groups->size(); j++) {
876  if (group_costs[j] + this_cost <= max_cost) {
877  (*groups)[j].push_back(i);
878  group_costs[j] += this_cost;
879  found_group = true;
880  break;
881  }
882  }
883  if (!found_group) { // Put this object in a newly created group.
884  groups->resize(groups->size() + 1);
885  groups->back().push_back(i);
886  group_costs.push_back(this_cost);
887  }
888  }
889 }
890 
892  const std::vector<const DiscriminativeNnetExample*> &input,
893  DiscriminativeNnetExample *output) {
// Concatenates the examples in input (in order) into the single example
// *output.  Between consecutive examples, the previous example's right-context
// rows and the next example's left-context rows become ordinary frames of
// *output, so num_ali and den_lat are padded there with
// left_context + right_context frames labelled arbitrary_tid.  Speaker info,
// if present, is appended as extra columns of input_frames, and
// output->spk_info is cleared.  weight and left_context are asserted equal
// across inputs; right_context is computed from input[0] only, and the other
// inputs are assumed to match it.
894  KALDI_ASSERT(!input.empty());
895  const DiscriminativeNnetExample &eg0 = *(input[0]);
896 
897  int32 dim = eg0.input_frames.NumCols() + eg0.spk_info.Dim(),
898  left_context = eg0.left_context,
899  num_frames = eg0.num_ali.size(),
900  right_context = eg0.input_frames.NumRows() - num_frames - left_context;
901 
902  int32 tot_frames = eg0.input_frames.NumRows(); // total frames (appended,
903  // with context)
904  for (size_t i = 1; i < input.size(); i++)
905  tot_frames += input[i]->input_frames.NumRows();
906 
907  int32 arbitrary_tid = 1; // arbitrary transition-id that we use to pad the
908  // num_ali and den_lat members between segments
909  // (since they're both the same, and the den-lat in
910  // those parts is linear, they contribute no
911  // derivative to the training).
912 
913  output->den_lat = eg0.den_lat;
914  output->num_ali = eg0.num_ali;
915  output->input_frames.Resize(tot_frames, dim, kUndefined);
916  output->input_frames.Range(0, eg0.input_frames.NumRows(),
917  0, eg0.input_frames.NumCols()).CopyFromMat(eg0.input_frames);
918  if (eg0.spk_info.Dim() != 0) {
919  output->input_frames.Range(0, eg0.input_frames.NumRows(),
920  eg0.input_frames.NumCols(), eg0.spk_info.Dim()).
921  CopyRowsFromVec(eg0.spk_info);
922  }
923 
924  output->num_ali.reserve(tot_frames - left_context - right_context);
925  output->weight = eg0.weight;
926  output->left_context = eg0.left_context;
927  output->spk_info.Resize(0);
928 
// Single-state compact lattice whose final weight carries the
// left_context + right_context padding transition-ids; it is concatenated
// between consecutive den-lats.
929  CompactLattice inter_segment_clat;
930  int32 initial = inter_segment_clat.AddState(); // state 0.
931  inter_segment_clat.SetStart(initial);
932 
933  std::vector<int32> inter_segment_ali(left_context + right_context);
934  std::fill(inter_segment_ali.begin(), inter_segment_ali.end(), arbitrary_tid);
935 
937  final_weight.SetString(inter_segment_ali);
938  inter_segment_clat.SetFinal(initial, final_weight);
939 
940  int32 feat_offset = eg0.input_frames.NumRows();
941 
942  for (size_t i = 1; i < input.size(); i++) {
943  const DiscriminativeNnetExample &eg_i = *(input[i]);
944 
945  output->input_frames.Range(feat_offset, eg_i.input_frames.NumRows(),
946  0, eg_i.input_frames.NumCols()).CopyFromMat(
947  eg_i.input_frames);
// NOTE(review): the dimension check below only runs when eg_i has speaker
// info; if eg_i has none while dim exceeds its feature dimension, those
// columns of input_frames stay uninitialized (allocated with kUndefined).
948  if (eg_i.spk_info.Dim() != 0) {
949  output->input_frames.Range(feat_offset, eg_i.input_frames.NumRows(),
950  eg_i.input_frames.NumCols(),
951  eg_i.spk_info.Dim()).CopyRowsFromVec(
952  eg_i.spk_info);
954  eg_i.spk_info.Dim() == dim);
955  }
956 
957  output->num_ali.insert(output->num_ali.end(),
958  inter_segment_ali.begin(), inter_segment_ali.end());
959  output->num_ali.insert(output->num_ali.end(),
960  eg_i.num_ali.begin(), eg_i.num_ali.end());
961  Concat(&(output->den_lat), inter_segment_clat);
962  Concat(&(output->den_lat), eg_i.den_lat);
963  KALDI_ASSERT(output->weight == eg_i.weight);
964  KALDI_ASSERT(output->left_context == eg_i.left_context);
965  feat_offset += eg_i.input_frames.NumRows();
966  }
967  KALDI_ASSERT(feat_offset == tot_frames);
968 }
969 
971  int32 max_length,
972  const std::vector<DiscriminativeNnetExample> &input,
973  std::vector<DiscriminativeNnetExample> *output) {
974 
975  std::vector<BaseFloat> costs(input.size());
976  for (size_t i = 0; i < input.size(); i++)
977  costs[i] = static_cast<BaseFloat>(input[i].input_frames.NumRows());
978  std::vector<std::vector<size_t> > groups;
979  SolvePackingProblem(max_length,
980  costs,
981  &groups);
982  output->clear();
983  output->resize(groups.size());
984  for (size_t i = 0; i < groups.size(); i++) {
985  std::vector<const DiscriminativeNnetExample*> group_egs;
986  for (size_t j = 0; j < groups[i].size(); j++) {
987  size_t index = groups[i][j];
988  group_egs.push_back(&(input[index]));
989  }
990  AppendDiscriminativeExamples(group_egs, &((*output)[i]));
991  }
992 }
993 
994 
995 
996 } // namespace nnet2
997 } // namespace kaldi
fst::StdArc::StateId StateId
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
fst::ArcTpl< LatticeWeight > LatticeArc
Definition: kaldi-lattice.h:40
const SplitDiscriminativeExampleConfig & config_
void AddRowSumMat(Real alpha, const MatrixBase< Real > &M, Real beta=1.0)
Does *this = alpha * (sum of rows of M) + beta * *this.
int32 LatticeStateTimes(const Lattice &lat, vector< int32 > *times)
This function iterates over the states of a topologically sorted lattice and counts the time instance...
static const LatticeWeightTpl One()
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
void SplitDiscriminativeExample(const SplitDiscriminativeExampleConfig &config, const TransitionModel &tmodel, const DiscriminativeNnetExample &eg, std::vector< DiscriminativeNnetExample > *egs_out, SplitExampleStats *stats_out)
Split a "discriminative example" into multiple pieces, splitting where the lattice has "pinch points"...
void ExciseDiscriminativeExample(const SplitDiscriminativeExampleConfig &config, const TransitionModel &tmodel, const DiscriminativeNnetExample &eg, std::vector< DiscriminativeNnetExample > *egs_out, SplitExampleStats *stats_out)
Remove unnecessary frames from discriminative training example.
void UpdateHash(const TransitionModel &tmodel, const DiscriminativeNnetExample &eg, std::string criterion, bool drop_frames, bool one_silence_class, Matrix< double > *hash, double *num_weight, double *den_weight, double *tot_t)
This function is used in code that tests the functionality that we provide here, about splitting and ...
kaldi::int32 int32
void CombineDiscriminativeExamples(int32 max_length, const std::vector< DiscriminativeNnetExample > &input, std::vector< DiscriminativeNnetExample > *output)
This function is used to combine multiple discriminative-training examples (each corresponding to a s...
int32 TransitionIdToPdf(int32 trans_id) const
This struct exists only for diagnostic purposes.
void CopyFromVec(const VectorBase< Real > &v)
Copy data from another vector (must match own size).
const size_t count
float BaseFloat
Definition: kaldi-types.h:29
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
Definition: posterior.h:42
void ExampleToPdfPost(const TransitionModel &tmodel, const std::vector< int32 > &silence_phones, std::string criterion, bool drop_frames, bool one_silence_class, const DiscriminativeNnetExample &eg, Posterior *post)
Given a discriminative training example, this function works out posteriors at the pdf level (note: t...
BaseFloat LatticeForwardBackwardMmi(const TransitionModel &tmodel, const Lattice &lat, const std::vector< int32 > &num_ali, bool drop_frames, bool convert_to_pdf_ids, bool cancel, Posterior *post)
This function can be used to compute posteriors for MMI, with a positive contribution for the numerat...
DiscriminativeExampleSplitter(const SplitDiscriminativeExampleConfig &config, const TransitionModel &tmodel, const DiscriminativeNnetExample &eg, std::vector< DiscriminativeNnetExample > *egs_out)
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
bool LatticeToDiscriminativeExample(const std::vector< int32 > &alignment, const Matrix< BaseFloat > &feats, const CompactLattice &clat, BaseFloat weight, int32 left_context, int32 right_context, DiscriminativeNnetExample *eg)
Converts lattice to discriminative training example.
static const CompactLatticeWeightTpl< WeightType, IntType > One()
void AppendDiscriminativeExamples(const std::vector< const DiscriminativeNnetExample *> &input, DiscriminativeNnetExample *output)
Appends the given vector of examples (which must be non-empty) into a single output example (called b...
void ConvertLattice(const ExpandedFst< ArcTpl< Weight > > &ifst, MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, Int > > > *ofst, bool invert)
Convert lattice from a normal FST to a CompactLattice FST.
Vector< BaseFloat > spk_info
spk_info contains any component of the features that varies slowly or not at all with time (and hence...
Definition: nnet-example.h:171
fst::VectorFst< LatticeArc > Lattice
Definition: kaldi-lattice.h:44
int32 CompactLatticeStateTimes(const CompactLattice &lat, vector< int32 > *times)
As LatticeStateTimes, but in the CompactLattice format.
BaseFloat LatticeForwardBackwardMpeVariants(const TransitionModel &trans, const std::vector< int32 > &silence_phones, const Lattice &lat, const std::vector< int32 > &num_ali, std::string criterion, bool one_silence_class, Posterior *post)
This function implements either the MPFE (minimum phone frame error) or SMBR (state-level minimum bay...
CompactLattice den_lat
The denominator lattice.
Definition: nnet-example.h:148
#define KALDI_WARN
Definition: kaldi-error.h:150
Matrix< BaseFloat > input_frames
The input data– typically with a number of frames [NumRows()] larger than labels.size(), because it includes features to the left and right as needed for the temporal context of the network.
Definition: nnet-example.h:159
std::vector< int32 > num_ali
The numerator alignment.
Definition: nnet-example.h:143
fst::StdArc::Label Label
BaseFloat weight
The weight we assign to this example; this will typically be one, but we include it for the sake of g...
Definition: nnet-example.h:140
fst::VectorFst< CompactLatticeArc > CompactLattice
Definition: kaldi-lattice.h:46
void ScalePosterior(BaseFloat scale, Posterior *post)
Scales the BaseFloat (weight) element in the posterior entries.
Definition: posterior.cc:218
This struct is used to store the information we need for discriminative training (MMI or MPE)...
Definition: nnet-example.h:136
A class representing a vector.
Definition: kaldi-vector.h:406
StateId GetOutputStateId(StateId s, unordered_map< StateId, StateId > *state_map, Lattice *lat_out)
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
void OutputOneSplit(int32 seg_begin, int32 seg_end)
SubMatrix< Real > Range(const MatrixIndexT row_offset, const MatrixIndexT num_rows, const MatrixIndexT col_offset, const MatrixIndexT num_cols) const
Return a sub-part of matrix.
Definition: kaldi-matrix.h:202
void Resize(const MatrixIndexT r, const MatrixIndexT c, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Sets matrix to a specified size (zero is OK as long as both r and c are zero).
void ConvertPosteriorToPdfs(const TransitionModel &tmodel, const Posterior &post_in, Posterior *post_out)
Converts a posterior over transition-ids to be a posterior over pdf-ids.
Definition: posterior.cc:322
int32 left_context
The number of frames of left context in the features (we can work out the #frames of right context fr...
Definition: nnet-example.h:164
#define KALDI_LOG
Definition: kaldi-error.h:153
Note on how to parse this filename: it contains functions related to neural-net training examples...
void CreateOutputLattice(int32 seg_begin, int32 seg_end, CompactLattice *clat_out)
Sub-matrix representation.
Definition: kaldi-matrix.h:988
void SolvePackingProblem(BaseFloat max_cost, const std::vector< BaseFloat > &costs, std::vector< std::vector< size_t > > *groups)
This function solves the "packing problem" using the "first fit" algorithm.
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:501
std::vector< DiscriminativeNnetExample > * egs_out_
Config structure for SplitExample, for splitting discriminative training examples.