21 #include "lat/lattice-functions.h"
23 namespace kaldi {
24 namespace nnet2 {
28  const std::vector<int32> &alignment,
29  const Matrix<BaseFloat> &feats,
30  const CompactLattice &clat,
31  BaseFloat weight,
32  int32 left_context,
33  int32 right_context,
// Populates *eg from a numerator alignment, the matching features and a
// denominator lattice.  Returns false (after logging a warning) if the
// alignment is empty, or if its length disagrees with the number of feature
// rows or with the lattice's frame count; returns true once eg has been
// filled in and has passed eg->Check().
35  KALDI_ASSERT(left_context >= 0 && right_context >= 0);
36  int32 num_frames = alignment.size();
37  if (num_frames == 0) {
38  KALDI_WARN << "Empty alignment";
39  return false;
40  }
41  if (num_frames != feats.NumRows()) {
42  KALDI_WARN << "Dimension mismatch: alignment " << num_frames
43  << " versus feats " << feats.NumRows();
44  return false;
45  }
// The state times themselves are not used below; only the frame count
// returned by CompactLatticeStateTimes is compared with the alignment length.
46  std::vector<int32> times;
47  int32 num_frames_clat = CompactLatticeStateTimes(clat, &times);
48  if (num_frames_clat != num_frames) {
49  KALDI_WARN << "Numerator/frames versus denlat frames mismatch: "
50  << num_frames << " versus " << num_frames_clat;
51  return false;
52  }
53  eg->weight = weight;
54  eg->num_ali = alignment;
55  eg->den_lat = clat;
// input_frames holds left_context padding rows, then the num_frames real
// rows, then right_context padding rows.
57  int32 feat_dim = feats.NumCols();
58  eg->input_frames.Resize(left_context + num_frames + right_context,
59  feat_dim);
60  eg->input_frames.Range(left_context, num_frames,
61  0, feat_dim).CopyFromMat(feats);
63  // Pad the context by duplicating the first and last feature frames.
64  for (int32 t = 0; t < left_context; t++)
65  eg->input_frames.Row(t).CopyFromVec(feats.Row(0));
66  for (int32 t = 0; t < right_context; t++)
67  eg->input_frames.Row(left_context + num_frames + t).CopyFromVec(
68  feats.Row(num_frames - 1));
70  eg->left_context = left_context;
71  eg->Check();
72  return true;
73 }
101  public:
// Constructor arguments are stored in the corresponding members (config_,
// tmodel_, eg_, egs_out_); the constructor body itself is empty.
103  const SplitDiscriminativeExampleConfig &config,
104  const TransitionModel &tmodel,
105  const DiscriminativeNnetExample &eg,
106  std::vector<DiscriminativeNnetExample> *egs_out):
107  config_(config), tmodel_(tmodel), eg_(eg), egs_out_(egs_out) { }
// Checks the input example and prepares lat_ with first_time == false.  If
// config_.excise is false the input example is copied unchanged as the single
// element of *egs_out_; otherwise DoExcise(stats) produces the output.
// NOTE(review): DoExcise reads frame_info_ and state_times_; confirm that
// they are populated (e.g. by ComputeFrameInfo()) before it is reached.
109  void Excise(SplitExampleStats *stats) {
110  eg_.Check();
111  PrepareLattice(false);
113  if (!config_.excise) {
114  egs_out_->resize(1);
115  (*egs_out_)[0] = eg_;
116  } else {
117  DoExcise(stats);
118  }
119  }
// If config_.split is false the input example is copied unchanged as the
// single element of *egs_out_ (without being checked).  Otherwise the example
// is checked, lat_ is prepared with first_time == true, and DoSplit(stats)
// produces the output segments.
// NOTE(review): DoSplit reads frame_info_; confirm that it is populated
// (e.g. by ComputeFrameInfo()) before DoSplit is reached.
121  void Split(SplitExampleStats *stats) {
122  if (!config_.split) {
123  egs_out_->resize(1);
124  (*egs_out_)[0] = eg_;
125  } else {
126  eg_.Check();
127  PrepareLattice(true);
129  DoSplit(stats);
130  }
131  }
133  private:
134  typedef LatticeArc Arc;
136  typedef Arc::Label Label;
138  // Converts the compact lattice to lat_.  first_time is passed as true
139  // from Split() and as false from Excise(); when it is false the
140  // transition-id collapsing and determinization steps are skipped (this
141  // saves time, as those steps are expected to have been done already).
142  void PrepareLattice(bool first_time);
144  void CollapseTransitionIds(); // Modifies the transition-ids on lat_ so that
145  // on each frame, there is just one with any
146  // given pdf-id. This allows us to determinize
147  // and minimize more completely.
// Fills frame_info_ with per-frame statistics of lat_; the visible code
// always returns true.
149  bool ComputeFrameInfo();
// Sets the output label of every arc in *lat to zero (epsilon).
151  static void RemoveAllOutputSymbols (Lattice *lat);
// Appends to *egs_out_ one example covering frames [seg_begin, seg_end).
153  void OutputOneSplit(int32 seg_begin, int32 seg_end);
155  void DoSplit(SplitExampleStats *stats);
157  void DoExcise(SplitExampleStats *stats);
// Number of frames in the example, taken from the numerator alignment length.
159  int32 NumFrames() const { return static_cast<int32>(eg_.num_ali.size()); }
164  // Put in clat_out a slice of lat_ with first frame at time "seg_begin" and
165  // with last frame at time "seg_end - 1".
166  void CreateOutputLattice(int32 seg_begin, int32 seg_end,
167  CompactLattice *clat_out);
169  // Returns the state-id in the output lattice that corresponds to state s
170  // of lat_ (creates a new output state, and records it in state_map, if needed).
171  StateId GetOutputStateId(StateId s,
172  unordered_map<StateId, StateId> *state_map,
173  Lattice *lat_out);
// Per-frame summary of the denominator lattice and its relation to the
// numerator alignment.  One instance is kept per time index in frame_info_.
175  struct FrameInfo {
177  int32 den_pdf_count; // number of distinct pdfs in denominator lattice
178  bool multiple_transition_ids; // true if there are multiple distinct
179  // transition-ids in the denominator lattice
180  // at this point
181  bool num_den_overlap; // true if num and den overlap.
183  bool nonzero_derivative; // True if we need to keep this frame because the
184  // derivative is nonzero on this frame.
185  bool can_excise_frame; // True if the frame, if part of a segment, can be
186  // excised, *but ignoring the effect of acoustic
187  // context*. I.e. true if the likelihoods and
188  // derivatives from this frame do not matter because
189  // the derivatives are zero and the likelihoods don't
190  // affect lattice posteriors (because pdfs are all
191  // the same on this frame, or if doing mpfe,
192  // transition-ids are all the same).
194  // start_state says, for a segment starting at frame t, what is the
195  // earliest state in lat_ that we have to consider including in the split
196  // lattice? This relates to a kind of optimization for efficiency.
197  StateId start_state;
199  // end_state says, for a segment whose final frame is time t (i.e. whose
200  // "segment end" is time t+1), what is the latest state in lat_ that we have
201  // to consider including in the split lattice? This relates to a kind of
202  // optimization for efficiency.
203  StateId end_state;
// Counts start at zero and flags at false.  start_state starts at the
// largest int32 and end_state at 0, so that the std::min / std::max updates
// applied while scanning the lattice states work from neutral values.
204  FrameInfo(): den_state_count(0), den_pdf_count(0),
205  multiple_transition_ids(false),
206  num_den_overlap(false), nonzero_derivative(false),
207  can_excise_frame(false),
208  start_state(std::numeric_limits<int32>::max()), end_state(0) { }
209  };
212  // The following variables are set in the initializer:
// Destination for the examples produced by Split() / Excise(); the pointer is
// supplied by the caller through the constructor.
216  std::vector<DiscriminativeNnetExample> *egs_out_;
218  Lattice lat_; // lattice generated from eg_.den_lat, with epsilons removed etc.
221  // The other variables are computed by Split() or functions called from it.
// Indexed by time; resized to NumFrames() + 1 entries by ComputeFrameInfo().
223  std::vector<FrameInfo> frame_info_;
225  // state_times_ says, for each state in lat_, what its start time is.
226  std::vector<int32> state_times_;
228 };
230 // Make sure that for any given pdf-id and any given frame, the den-lat has
231 // only one transition-id mapping to that pdf-id, on the same frame.
232 // It helps us to more completely minimize the lattice. Note: we
233 // can't do this if the criterion is MPFE, because in that case the
234 // objective function will be affected by the phone-identities being
235 // different even if the pdf-ids are the same.
237  std::vector<int32> times;
238  TopSort(&lat_); // Topologically sort the lattice (required by
239  // LatticeStateTimes)
240  int32 num_frames = LatticeStateTimes(lat_, &times);
241  StateId num_states = lat_.NumStates();
// pdf_to_tid[t] maps each pdf-id seen so far on frame t to the first
// transition-id encountered for it; later arcs with the same pdf-id on that
// frame are relabelled to that transition-id.
243  std::vector<std::map<int32, int32> > pdf_to_tid(num_frames);
244  for (StateId s = 0; s < num_states; s++) {
245  int32 t = times[s];
246  for (fst::MutableArcIterator<Lattice> aiter(&lat_, s);
247  !aiter.Done(); aiter.Next()) {
248  KALDI_ASSERT(t >= 0 && t < num_frames);
249  Arc arc = aiter.Value();
// Arcs must be non-epsilon with identical input and output labels.
250  KALDI_ASSERT(arc.ilabel != 0 && arc.ilabel == arc.olabel);
251  int32 pdf = tmodel_.TransitionIdToPdf(arc.ilabel);
252  if (pdf_to_tid[t].count(pdf) != 0) {
253  arc.ilabel = arc.olabel = pdf_to_tid[t][pdf];
254  aiter.SetValue(arc);
255  } else {
256  pdf_to_tid[t][pdf] = arc.ilabel;
257  }
258  }
259  }
260 }
266  Project(&lat_, fst::PROJECT_INPUT); // Get rid of the word labels and put the
267  // transition-ids on both sides.
269  RmEpsilon(&lat_); // Remove epsilons.. this simplifies
270  // certain things.
// The determinization work below is only done when first_time is true.
272  if (first_time) {
276  if (config_.determinize) {
277  if (!config_.minimize) {
// Plain determinization only.
278  Lattice det_lat;
279  Determinize(lat_, &det_lat);
280  lat_ = det_lat;
281  } else {
// Used in place of determinize-then-Minimize (see the note below):
// reverse, determinize, reverse again, determinize again, then remove
// epsilons.
282  Lattice tmp_lat;
283  Reverse(lat_, &tmp_lat);
284  Determinize(tmp_lat, &lat_);
285  Reverse(lat_, &tmp_lat);
286  Determinize(tmp_lat, &lat_);
287  RmEpsilon(&lat_);
288  // Previously we determinized, then did
289  // Minimize(&lat_);
290  // but this was too slow.
291  }
292  }
293  }
294  TopSort(&lat_); // Topologically sort the lattice.
295 }
297 // This function computes, for each frame, summary information about the
298 // lattice lat_ on that frame, and stores it in frame_info_.
301  int32 num_frames = NumFrames();
// frame_info_ gets one extra entry (index num_frames) because state times
// may equal num_frames (see the assert on t below).
303  frame_info_.clear();
304  frame_info_.resize(num_frames + 1);
308  std::vector<std::set<int32> > pdfs_per_frame(num_frames),
309  tids_per_frame(num_frames);
311  int32 num_states = lat_.NumStates();
313  for (int32 state = 0; state < num_states; state++) {
314  int32 t = state_times_[state];
315  KALDI_ASSERT(t >= 0 && t <= num_frames);
316  frame_info_[t].den_state_count++;
// NOTE(review): the two per-frame sets have only num_frames entries, so the
// inserts below assume that a state with t == num_frames has no outgoing
// arcs -- TODO confirm.
317  for (fst::ArcIterator<Lattice> aiter(lat_, state); !aiter.Done();
318  aiter.Next()) {
319  const LatticeArc &arc = aiter.Value();
320  KALDI_ASSERT(arc.ilabel != 0 && arc.ilabel == arc.olabel);
321  int32 transition_id = arc.ilabel,
322  pdf_id = tmodel_.TransitionIdToPdf(transition_id);
323  tids_per_frame[t].insert(transition_id);
324  pdfs_per_frame[t].insert(pdf_id);
325  }
// A state at time t lowers start_state of frame t+1 (min over states) and
// raises end_state of frame t (max over states).
326  if (t < num_frames)
327  frame_info_[t+1].start_state = std::min(state,
328  frame_info_[t+1].start_state);
329  frame_info_[t].end_state = std::max(state,
331  }
// Forward pass over frames that takes a max with the previous frame's
// end_state.
333  for (int32 i = 1; i <= NumFrames(); i++)
334  frame_info_[i].end_state = std::max(frame_info_[i-1].end_state,
336  for (int32 i = NumFrames() - 1; i >= 0; i--)
// Second pass: derive the per-frame flags from the sets collected above and
// from the numerator alignment.
340  for (int32 t = 0; t < num_frames; t++) {
341  FrameInfo &frame_info = frame_info_[t];
342  int32 transition_id = eg_.num_ali[t],
343  pdf_id = tmodel_.TransitionIdToPdf(transition_id);
344  frame_info.num_den_overlap = (pdfs_per_frame[t].count(pdf_id) != 0);
345  frame_info.multiple_transition_ids = (tids_per_frame[t].size() > 1);
346  KALDI_ASSERT(!pdfs_per_frame[t].empty());
347  frame_info.den_pdf_count = pdfs_per_frame[t].size();
349  if (config_.criterion == "mpfe" || config_.criterion == "smbr") {
350  frame_info.nonzero_derivative = (frame_info.den_pdf_count > 1);
351  } else {
352  KALDI_ASSERT(config_.criterion == "mmi");
// NOTE(review): both branches below are described in terms of ">1 active
// pdf", but the code tests den_state_count (number of lattice states on the
// frame) rather than den_pdf_count -- confirm which is intended.
353  if (config_.drop_frames) {
354  // With frame dropping, we'll get nonzero derivative only
355  // if num and den overlap, *and* den has >1 active pdf.
356  frame_info.nonzero_derivative = frame_info.num_den_overlap &&
357  frame_info.den_state_count > 1;
358  } else {
359  // Without frame dropping, we'll get nonzero derivative if num and den
360  // do not overlap , or den has >1 active pdf.
361  frame_info.nonzero_derivative = !frame_info.num_den_overlap ||
362  frame_info.den_state_count > 1;
363  }
364  }
365  // If a frame is part of a segment, but it's not going to contribute
366  // to the derivative and the den lattice has only one pdf active
367  // at that time, then this frame can be excised from the lattice
368  // because it will not affect the posteriors around it.
369  if (config_.criterion == "mpfe") {
370  frame_info.can_excise_frame =
371  !frame_info.nonzero_derivative && \
372  !frame_info.multiple_transition_ids;
373  // in the mpfe case, if there are multiple transition-ids on a
374  // frame there may be multiple phones on a frame, which could
375  // contribute to the objective function even if they share pdf-ids.
376  // (this was an issue that came up during testing).
377  } else {
378  frame_info.can_excise_frame =
379  !frame_info.nonzero_derivative && frame_info.den_pdf_count == 1;
380  }
381  }
382  return true;
383 }
386 /* Excising a frame means removing a frame from the lattice and removing the
387  corresponding feature. We can only do this if it would not affect the
388  derivatives because the current frame has zero derivative and also all the
389  den-lat pdfs are the same on this frame (so removing the frame doesn't affect
390  the lattice posteriors). But we can't remove a frame if doing so would
391  affect the acoustic context. Generally speaking we must keep all frames
392  that are within LeftContext() to the left and RightContext() to the right
393  of a frame that we can't excise, *but* it's OK at the edges of a segment
394  even if they are that close to other frames, because we anyway keep a few
395  frames of context at the edges, and we can just make sure to keep the
396  *right* few frames of context.
397  */
399  int32 left_context = eg_.left_context,
400  right_context = RightContext(),
401  num_frames = NumFrames();
402  // Compute, for each frame, whether we can excise it.
403  //
404  std::vector<bool> can_excise(num_frames, false);
406  bool need_some_frame = false;
407  for (int32 t = 0; t < num_frames; t++) {
408  can_excise[t] = frame_info_[t].can_excise_frame;
409  if (!can_excise[t])
410  need_some_frame = true;
411  }
412  if (!need_some_frame) { // We don't need any frame within this file, so simply
413  // delete the segment.
414  KALDI_WARN << "Example completely removed when excising."; // unexpected,
415  // as the segment should have been deleted when splitting.
416  egs_out_->clear();
417  return;
418  }
419  egs_out_->resize(1);
420  DiscriminativeNnetExample &eg_out = (*egs_out_)[0];
422  // start_t and end_t will be the central part of the segment, excluding any
423  // frames at the edges that we can excise.
// Both scans stop inside the array because need_some_frame guarantees that
// at least one entry of can_excise is false.
424  int32 start_t, end_t;
425  for (start_t = 0; can_excise[start_t]; start_t++);
426  for (end_t = num_frames; can_excise[end_t-1]; end_t--);
428  // for frames from start_t to end_t-1, do not excise them if
429  // they are within the context-window of a frame that we need to keep.
430  // Note: we do t2 = t - right_context to t + left_context, because we're
431  // concerned whether frame t2 has frame t in its window... it might
432  // seem a bit backwards.
// Frames outside [start_t, end_t) keep will_excise == true, since can_excise
// is true for all of them.
433  std::vector<bool> will_excise(can_excise);
434  for (int32 t = start_t; t < end_t; t++) {
435  for (int32 t2 = t - right_context; t2 <= t + left_context; t2++)
436  if (t2 >= start_t && t2 < end_t && !can_excise[t2])
437  will_excise[t] = false; // can't excise this frame, it's needed for
438  // context.
439  }
441  // Remove all un-needed frames from the lattice by replacing the
442  // symbols with epsilon and then removing the epsilons.
443  // Note, this operation is destructive (it changes lat_).
444  int32 num_states = lat_.NumStates();
445  for (int32 state = 0; state < num_states; state++) {
446  int32 t = state_times_[state];
447  for (::fst::MutableArcIterator<Lattice> aiter(&lat_, state); !aiter.Done();
448  aiter.Next()) {
449  Arc arc = aiter.Value();
450  if (will_excise[t]) {
451  arc.ilabel = arc.olabel = 0;
452  aiter.SetValue(arc);
453  }
454  }
455  }
456  RmEpsilon(&lat_);
458  ConvertLattice(lat_, &eg_out.den_lat);
// Keep only the alignment entries of the frames that were not excised.
460  eg_out.num_ali.clear();
461  int32 num_frames_kept = 0;
462  for (int32 t = 0; t < num_frames; t++) {
463  if (!will_excise[t]) {
464  eg_out.num_ali.push_back(eg_.num_ali[t]);
465  num_frames_kept++;
466  }
467  }
469  stats->num_frames_kept_after_excise += num_frames_kept;
471  num_frames_kept);
// Output features: left_context rows, then the kept frames, then
// right_context rows.
473  int32 num_frames_kept_plus = num_frames_kept + left_context + right_context;
474  eg_out.input_frames.Resize(num_frames_kept_plus,
477  // the left-context of the output will be shifted to the right by
478  // start_t.
479  for (int32 i = 0; i < left_context; i++) {
480  SubVector<BaseFloat> dst(eg_out.input_frames, i);
481  SubVector<BaseFloat> src(eg_.input_frames, start_t + i);
482  dst.CopyFromVec(src);
483  }
484  // the right-context will also be shifted, we take the frames
485  // to the right of end_t.
486  for (int32 i = 0; i < right_context; i++) {
488  num_frames_kept + left_context + i);
490  end_t + left_context + i);
491  dst.CopyFromVec(src);
492  }
493  // now copy the central frames (those that were not excised).
494  int32 dst_t = 0;
495  for (int32 t = start_t; t < end_t; t++) {
496  if (!will_excise[t]) {
498  left_context + dst_t);
500  left_context + t);
501  dst.CopyFromVec(src);
502  dst_t++;
503  }
504  }
505  KALDI_ASSERT(dst_t == num_frames_kept);
// The remaining fields are carried over unchanged from the input example.
508  eg_out.weight = eg_.weight;
509  eg_out.left_context = eg_.left_context;
510  eg_out.spk_info = eg_.spk_info;
512  eg_out.Check();
513 }
// Splits the example at lattice "pinch points" and emits, through
// OutputOneSplit(), each resulting segment that is marked as kept.
517  std::vector<int32> split_points;
518  int32 num_frames = NumFrames();
519  {
520  // Make the "split points" 0 and num_frames, and
521  // any frame that has just one state on it and the previous
522  // frame had >1 state. This gives us one split for each
523  // "pinch point" in the lattice. Later we may move each split
524  // to a more optimal location.
525  split_points.push_back(0);
526  for (int32 t = 1; t < num_frames; t++) {
527  if (frame_info_[t].den_state_count == 1 &&
528  frame_info_[t-1].den_state_count > 1)
529  split_points.push_back(t);
530  }
531  split_points.push_back(num_frames);
532  }
// split_points always holds at least the two entries 0 and num_frames, so
// is_kept has at least one element.
534  std::vector<bool> is_kept(split_points.size() - 1);
535  { // A "split" is a pair of successive split points. Work out for each split
536  // whether we must keep it (we must if it contains at least one frame for
537  // which "nonzero_derivative" == true.)
538  for (size_t s = 0; s < is_kept.size(); s++) {
539  int32 start = split_points[s], end = split_points[s+1];
540  bool keep_this_split = false;
541  for (int32 t = start; t < end; t++)
543  keep_this_split = true;
544  is_kept[s] = keep_this_split;
545  }
546  }
548  egs_out_->clear();
549  egs_out_->reserve(is_kept.size());
// Accumulate statistics over all segments, kept or not.
551  stats->num_lattices++;
552  stats->longest_lattice = std::max(stats->longest_lattice, num_frames);
553  stats->num_segments += is_kept.size();
554  stats->num_frames_orig += num_frames;
555  for (int32 t = 0; t < num_frames; t++)
557  stats->num_frames_must_keep++;
// Emit one output example per kept segment and update the kept-segment stats.
559  for (size_t s = 0; s < is_kept.size(); s++) {
560  if (is_kept[s]) {
561  stats->num_kept_segments++;
562  OutputOneSplit(split_points[s], split_points[s+1]);
563  int32 segment_len = split_points[s+1] - split_points[s];
564  stats->num_frames_kept_after_split += segment_len;
566  std::max(stats->longest_segment_after_split, segment_len);
567  }
568  }
569 }
// Logs a summary of the accumulated splitting/excising statistics.
// NOTE(review): the ratios below divide by num_lattices and num_frames_orig
// with no guard against those counters being zero -- confirm callers only
// print after at least one lattice has been processed.
574  KALDI_LOG << "Split " << num_lattices << " lattices. Stats:";
575  double kept_segs_per_lat = num_kept_segments * 1.0 / num_lattices,
576  segs_per_lat = num_segments * 1.0 / num_lattices;
578  KALDI_LOG << "Made on average " << segs_per_lat << " segments per lattice, "
579  << "of which " << kept_segs_per_lat << " were kept.";
// Percentages are all relative to the original number of frames.
581  double percent_needed = num_frames_must_keep * 100.0 / num_frames_orig,
582  percent_after_split = num_frames_kept_after_split * 100.0 / num_frames_orig,
583  percent_after_excise = num_frames_kept_after_excise * 100.0 / num_frames_orig;
585  KALDI_LOG << "Needed to keep " << percent_needed << "% of frames, after split "
586  << "kept " << percent_after_split << "%, after excising frames kept "
587  << percent_after_excise << "%.";
589  KALDI_LOG << "Longest lattice had " << longest_lattice
590  << " frames, longest segment after splitting had "
591  << longest_segment_after_split
592  << " frames, longest segment after excising had "
593  << longest_segment_after_excise;
594 }
597  int32 seg_end) {
// Appends to *egs_out_ a new example covering frames [seg_begin, seg_end) of
// eg_: the matching slice of the alignment, a lattice slice built by
// CreateOutputLattice(), and the feature rows for the segment plus context.
598  KALDI_ASSERT(seg_begin >= 0 && seg_end > seg_begin && seg_end <= NumFrames());
599  egs_out_->resize(egs_out_->size() + 1);
600  int32 left_context = eg_.left_context, right_context = RightContext(),
601  tot_context = left_context + right_context;
602  DiscriminativeNnetExample &eg_out = egs_out_->back();
603  eg_out.weight = eg_.weight;
605  eg_out.num_ali.insert(eg_out.num_ali.end(),
606  eg_.num_ali.begin() + seg_begin,
607  eg_.num_ali.begin() + seg_end);
609  CreateOutputLattice(seg_begin, seg_end, &(eg_out.den_lat));
// The row range starts at seg_begin and spans the segment length plus
// tot_context rows, i.e. it includes left and right context for the segment.
611  eg_out.input_frames = eg_.input_frames.Range(seg_begin, seg_end - seg_begin +
612  tot_context,
613  0, eg_.input_frames.NumCols());
615  eg_out.left_context = eg_.left_context;
617  eg_out.spk_info = eg_.spk_info;
619  eg_out.Check();
620 }
622 // static
// Sets the output label of every arc of *lat to 0; input labels, weights and
// destination states are left unchanged.
624  for (StateId s = 0; s < lat->NumStates(); s++) {
625  for (::fst::MutableArcIterator<Lattice> aiter(lat, s); !aiter.Done();
626  aiter.Next()) {
627  Arc arc = aiter.Value();
628  arc.olabel = 0;
629  aiter.SetValue(arc);
630  }
631  }
632 }
636  StateId s, unordered_map<StateId, StateId> *state_map, Lattice *lat_out) {
// First request for s: add a state to *lat_out, remember the mapping, and
// return the new id.  Later requests return the remembered id.
637  if (state_map->count(s) == 0) {
638  return ((*state_map)[s] = lat_out->AddState());
639  } else {
640  return (*state_map)[s];
641  }
642 }
645  int32 seg_begin, int32 seg_end,
646  CompactLattice *clat_out) {
// Copies into *clat_out the part of lat_ whose state times lie in
// [seg_begin, seg_end], with output symbols removed.
647  Lattice lat_out;
649  // Below, state_map will map from states in the original lattice
650  // lat_ to ones in the new lattice lat_out.
651  unordered_map<StateId, StateId> state_map;
653  // The range of the loop over s could be made over the
654  // entire lattice, but we limit it for efficiency.
656  for (StateId s = frame_info_[seg_begin].start_state;
657  s <= frame_info_[seg_end].end_state; s++) {
658  int32 t = state_times_[s];
660  if (t < seg_begin || t > seg_end) // state out of range.
661  continue;
663  int32 this_state = GetOutputStateId(s, &state_map, &lat_out);
665  if (t == seg_begin) // note: we only split on frames with just one
666  lat_out.SetStart(this_state); // state, so we reach this only once.
// States at the segment end become final.  The original final weight is kept
// only when the segment ends at the end of the whole example; otherwise the
// weight One() is used.
668  if (t == seg_end) { // Make it final and don't process its arcs out.
669  if (seg_end == NumFrames()) {
670  lat_out.SetFinal(this_state, lat_.Final(s));
671  } else {
672  lat_out.SetFinal(this_state, LatticeWeight::One());
673  }
674  continue; // don't process arcs out of this state.
675  }
// Copy every outgoing arc, remapping the destination state id.
677  for (fst::ArcIterator<Lattice> aiter(lat_, s); !aiter.Done(); aiter.Next()) {
678  const Arc &arc = aiter.Value();
679  StateId next_state = GetOutputStateId(arc.nextstate,
680  &state_map, &lat_out);
681  KALDI_ASSERT(arc.ilabel != 0 && arc.ilabel == arc.olabel); // We expect no epsilons.
682  lat_out.AddArc(this_state, Arc(arc.ilabel, arc.olabel, arc.weight,
683  next_state));
684  }
685  }
686  Connect(&lat_out); // this is not really necessary, it's only to make sure
687  // the assert below fails when it should. TODO: remove it.
688  KALDI_ASSERT(lat_out.NumStates() > 0);
689  RemoveAllOutputSymbols(&lat_out);
690  ConvertLattice(lat_out, clat_out);
691 }
693 /*
694 void DiscriminativeExampleSplitter::SelfTest() {
695  bool splits_ok = true; // True iff we split only
696  // on frames where there was
697  // one arc crossing.
699  // we can't do any of this excising frames if we want to
700  // preserve equivalence.
701  std::fill(can_excise_.begin(), can_excise_.end(), false);
703  std::vector<Lattice*> split_lats;
705  int32 cur_t = NumFrames();
706  while (cur_t != 0) {
707  Backtrace this_backtrace = backtrace_[cur_t];
708  int32 prev_t = this_backtrace.prev_frame;
710  int32 seg_begin = prev_t, seg_end = cur_t;
711  Lattice *new_lat = new Lattice();
712  CreateOutputLattice(seg_begin, seg_end, new_lat);
713  split_lats.push_back(new_lat);
715  if (split_penalty_[cur_t] != 0)
716  splits_ok = false; // we split where there was a penalty so we don't
717  // expect equivalence.
718  cur_t = prev_t;
719  }
720  KALDI_ASSERT(!split_lats.empty());
721  std::reverse(split_lats.begin(), split_lats.end());
722  for (size_t i = 1; i < split_lats.size(); i++) {
723  // append split_lats[i] to split_lats[0], putting the
724  // result in split_lats[0].
725  Concat(split_lats[0], *(split_lats[i]));
726  }
727  Connect(split_lats[0]);
728  KALDI_ASSERT(split_lats[0]->NumStates() > 0);
731  if (!splits_ok) {
732  KALDI_LOG << "Not self-testing because we split where there were "
733  << "multiple paths.";
735  } else {
736  if (!(RandEquivalent(*(split_lats[0]), lat_, 5, 0.01,
737  Rand(), 100))) {
738  KALDI_WARN << "Lattices were not equivalent (self-test failed).";
739  KALDI_LOG << "Original lattice was: ";
740  WriteLattice(std::cerr, false, lat_);
741  KALDI_LOG << "New lattice is:";
742  WriteLattice(std::cerr, false, *(split_lats[0]));
743  {
744  Lattice best_path_orig;
745  ShortestPath(lat_, &best_path_orig);
746  KALDI_LOG << "Original best path was:";
747  WriteLattice(std::cerr, false, best_path_orig);
748  }
749  {
750  Lattice best_path_new;
751  ShortestPath(*(split_lats[0]), &best_path_new);
752  KALDI_LOG << "New best path was:";
753  WriteLattice(std::cerr, false, best_path_new);
754  }
755  }
756  }
757  for (size_t i = 0; i < split_lats.size(); i++)
758  delete split_lats[i];
759 }
760 */
765  const SplitDiscriminativeExampleConfig &config,
766  const TransitionModel &tmodel,
767  const DiscriminativeNnetExample &eg,
768  std::vector<DiscriminativeNnetExample> *egs_out,
769  SplitExampleStats *stats_out) {
// Thin wrapper: builds a splitter for this example and runs its Split(),
// which writes the results to *egs_out and updates *stats_out.
770  DiscriminativeExampleSplitter splitter(config, tmodel, eg, egs_out);
771  splitter.Split(stats_out);
772 }
776  const SplitDiscriminativeExampleConfig &config,
777  const TransitionModel &tmodel,
778  const DiscriminativeNnetExample &eg,
779  std::vector<DiscriminativeNnetExample> *egs_out,
780  SplitExampleStats *stats_out) {
// Thin wrapper: builds a splitter for this example and runs its Excise(),
// which writes the results to *egs_out and updates *stats_out.
781  DiscriminativeExampleSplitter splitter(config, tmodel, eg, egs_out);
782  splitter.Excise(stats_out);
783 }
787  const TransitionModel &tmodel,
788  const DiscriminativeNnetExample &eg,
789  std::string criterion,
790  bool drop_frames,
791  bool one_silence_class,
792  Matrix<double> *hash,
793  double *num_weight,
794  double *den_weight,
795  double *tot_t) {
// Accumulates into *hash, per pdf-id, the posterior-weighted average feature
// over each frame's context window; also adds positive posterior weight to
// *num_weight, the magnitude of the other weights to *den_weight, and the
// frame count to *tot_t.  All outputs are accumulated, not reset.
796  int32 feat_dim = eg.input_frames.NumCols(),
797  left_context = eg.left_context,
798  num_frames = eg.num_ali.size(),
799  right_context = eg.input_frames.NumRows() - num_frames - left_context,
800  context_width = left_context + 1 + right_context;
801  *tot_t += num_frames;
802  KALDI_ASSERT(right_context >= 0);
803  KALDI_ASSERT(hash != NULL);
// An empty hash is sized on first use; otherwise its shape must already match.
804  if (hash->NumRows() == 0) {
805  hash->Resize(tmodel.NumPdfs(), feat_dim);
806  } else {
807  KALDI_ASSERT(hash->NumRows() == tmodel.NumPdfs() &&
808  hash->NumCols() == feat_dim);
809  }
811  Posterior post;
812  std::vector<int32> silence_phones; // we don't let the user specify this
813  // because it's not necessary for testing
814  // purposes -> leave it empty
815  ExampleToPdfPost(tmodel, silence_phones, criterion, drop_frames,
816  one_silence_class, eg, &post);
818  Vector<BaseFloat> avg_feat(feat_dim);
820  for (int32 t = 0; t < num_frames; t++) {
// Rows t .. t + context_width - 1 of input_frames form frame t's window.
821  SubMatrix<BaseFloat> context_window(eg.input_frames,
822  t, context_width,
823  0, feat_dim);
824  // set avg_feat to average over the context-window for this frame.
825  avg_feat.AddRowSumMat(1.0 / context_width, context_window, 0.0);
826  Vector<double> avg_feat_dbl(avg_feat);
827  for (size_t i = 0; i < post[t].size(); i++) {
828  int32 pdf_id = post[t][i].first;
829  BaseFloat weight = post[t][i].second;
830  hash->Row(pdf_id).AddVec(weight, avg_feat_dbl);
831  if (weight > 0.0) *num_weight += weight;
832  else *den_weight += -weight;
833  }
834  }
835 }
839  const TransitionModel &tmodel,
840  const std::vector<int32> &silence_phones,
841  std::string criterion,
842  bool drop_frames,
843  bool one_silence_class,
844  const DiscriminativeNnetExample &eg,
845  Posterior *post) {
// Computes pdf-level posteriors for the example under the given criterion
// ("mpfe", "smbr" or "mmi") and scales them by eg.weight.
846  KALDI_ASSERT(criterion == "mpfe" || criterion == "smbr" || criterion == "mmi");
848  Lattice lat;
849  ConvertLattice(eg.den_lat, &lat);
850  TopSort(&lat);
851  if (criterion == "mpfe" || criterion == "smbr") {
// This branch produces transition-id posteriors first, which are then
// converted to pdf-ids.
852  Posterior tid_post;
853  LatticeForwardBackwardMpeVariants(tmodel, silence_phones, lat, eg.num_ali,
854  criterion, one_silence_class, &tid_post);
856  ConvertPosteriorToPdfs(tmodel, tid_post, post);
857  } else {
// MMI branch: the forward-backward is asked to output pdf-ids directly
// (convert_to_pdf_ids) with the `cancel` option enabled.
858  bool convert_to_pdf_ids = true, cancel = true;
859  LatticeForwardBackwardMmi(tmodel, lat, eg.num_ali,
860  drop_frames, convert_to_pdf_ids, cancel,
861  post);
862  }
863  ScalePosterior(eg.weight, post);
864 }
868  const std::vector<BaseFloat> &costs,
869  std::vector<std::vector<size_t> > *groups) {
// Packs item indices into groups in input order: each item joins the first
// existing group whose running cost stays within max_cost, otherwise it
// starts a new group.  An item is always placed, so one whose cost alone
// exceeds max_cost ends up as the first member of its own group.
870  groups->clear();
871  std::vector<BaseFloat> group_costs;
872  for (size_t i = 0; i < costs.size(); i++) {
873  bool found_group = false;
874  BaseFloat this_cost = costs[i];
875  for (size_t j = 0; j < groups->size(); j++) {
876  if (group_costs[j] + this_cost <= max_cost) {
877  (*groups)[j].push_back(i);
878  group_costs[j] += this_cost;
879  found_group = true;
880  break;
881  }
882  }
883  if (!found_group) { // Put this object in a newly created group.
884  groups->resize(groups->size() + 1);
885  groups->back().push_back(i);
886  group_costs.push_back(this_cost);
887  }
888  }
889 }
892  const std::vector<const DiscriminativeNnetExample*> &input,
893  DiscriminativeNnetExample *output) {
// Concatenates the input examples into *output.  Features are stacked row-wise
// (with any speaker vector appended as extra columns on every row), and
// between consecutive examples the alignment and lattice are padded with
// left_context + right_context copies of an arbitrary transition-id.
// Requires a non-empty input; weight and left_context must agree across inputs.
894  KALDI_ASSERT(!input.empty());
895  const DiscriminativeNnetExample &eg0 = *(input[0]);
// Context sizes and output dimension are taken from the first example.
897  int32 dim = eg0.input_frames.NumCols() + eg0.spk_info.Dim(),
898  left_context = eg0.left_context,
899  num_frames = eg0.num_ali.size(),
900  right_context = eg0.input_frames.NumRows() - num_frames - left_context;
902  int32 tot_frames = eg0.input_frames.NumRows(); // total frames (appended,
903  // with context)
904  for (size_t i = 1; i < input.size(); i++)
905  tot_frames += input[i]->input_frames.NumRows();
907  int32 arbitrary_tid = 1; // arbitrary transition-id that we use to pad the
908  // num_ali and den_lat members between segments
909  // (since they're both the same, and the den-lat in
910  // those parts is linear, they contribute no
911  // derivative to the training).
// Start the output from the first example.
913  output->den_lat = eg0.den_lat;
914  output->num_ali = eg0.num_ali;
915  output->input_frames.Resize(tot_frames, dim, kUndefined);
916  output->input_frames.Range(0, eg0.input_frames.NumRows(),
917  0, eg0.input_frames.NumCols()).CopyFromMat(eg0.input_frames);
918  if (eg0.spk_info.Dim() != 0) {
919  output->input_frames.Range(0, eg0.input_frames.NumRows(),
920  eg0.input_frames.NumCols(), eg0.spk_info.Dim()).
921  CopyRowsFromVec(eg0.spk_info);
922  }
924  output->num_ali.reserve(tot_frames - left_context - right_context);
925  output->weight = eg0.weight;
926  output->left_context = eg0.left_context;
// Speaker information now lives in the feature columns, so the separate
// vector is emptied.
927  output->spk_info.Resize(0);
// Single-state lattice whose final weight carries the padding transition-ids;
// it is concatenated between consecutive examples below.
929  CompactLattice inter_segment_clat;
930  int32 initial = inter_segment_clat.AddState(); // state 0.
931  inter_segment_clat.SetStart(initial);
933  std::vector<int32> inter_segment_ali(left_context + right_context);
934  std::fill(inter_segment_ali.begin(), inter_segment_ali.end(), arbitrary_tid);
937  final_weight.SetString(inter_segment_ali);
938  inter_segment_clat.SetFinal(initial, final_weight);
// feat_offset is the next free row of output->input_frames.
940  int32 feat_offset = eg0.input_frames.NumRows();
942  for (size_t i = 1; i < input.size(); i++) {
943  const DiscriminativeNnetExample &eg_i = *(input[i]);
945  output->input_frames.Range(feat_offset, eg_i.input_frames.NumRows(),
946  0, eg_i.input_frames.NumCols()).CopyFromMat(
947  eg_i.input_frames);
948  if (eg_i.spk_info.Dim() != 0) {
949  output->input_frames.Range(feat_offset, eg_i.input_frames.NumRows(),
950  eg_i.input_frames.NumCols(),
951  eg_i.spk_info.Dim()).CopyRowsFromVec(
952  eg_i.spk_info);
954  eg_i.spk_info.Dim() == dim);
955  }
// Padding first, then this example's alignment and lattice.
957  output->num_ali.insert(output->num_ali.end(),
958  inter_segment_ali.begin(), inter_segment_ali.end());
959  output->num_ali.insert(output->num_ali.end(),
960  eg_i.num_ali.begin(), eg_i.num_ali.end());
961  Concat(&(output->den_lat), inter_segment_clat);
962  Concat(&(output->den_lat), eg_i.den_lat);
963  KALDI_ASSERT(output->weight == eg_i.weight);
964  KALDI_ASSERT(output->left_context == eg_i.left_context);
965  feat_offset += eg_i.input_frames.NumRows();
966  }
967  KALDI_ASSERT(feat_offset == tot_frames);
968 }
971  int32 max_length,
972  const std::vector<DiscriminativeNnetExample> &input,
973  std::vector<DiscriminativeNnetExample> *output) {
// Groups the input examples with SolvePackingProblem(), using each example's
// number of feature rows as its cost and max_length as the limit, then
// appends each group into one output example.
975  std::vector<BaseFloat> costs(input.size());
976  for (size_t i = 0; i < input.size(); i++)
977  costs[i] = static_cast<BaseFloat>(input[i].input_frames.NumRows());
978  std::vector<std::vector<size_t> > groups;
979  SolvePackingProblem(max_length,
980  costs,
981  &groups);
982  output->clear();
983  output->resize(groups.size());
984  for (size_t i = 0; i < groups.size(); i++) {
// Collect pointers to this group's examples (no copies are made here).
985  std::vector<const DiscriminativeNnetExample*> group_egs;
986  for (size_t j = 0; j < groups[i].size(); j++) {
987  size_t index = groups[i][j];
988  group_egs.push_back(&(input[index]));
989  }
990  AppendDiscriminativeExamples(group_egs, &((*output)[i]));
991  }
992 }
996 } // namespace nnet2
997 } // namespace kaldi
