DiscriminativeExampleSplitter Class Reference

For each frame, judge whether it produces a nonzero derivative and whether the lattice can be split there.


Classes

struct  FrameInfo
 

Public Member Functions

 DiscriminativeExampleSplitter (const SplitDiscriminativeExampleConfig &config, const TransitionModel &tmodel, const DiscriminativeNnetExample &eg, std::vector< DiscriminativeNnetExample > *egs_out)
 
void Excise (SplitExampleStats *stats)
 
void Split (SplitExampleStats *stats)
 

Private Types

typedef LatticeArc Arc
 
typedef Arc::StateId StateId
 
typedef Arc::Label Label
 

Private Member Functions

void PrepareLattice (bool first_time)
 
void CollapseTransitionIds ()
 
bool ComputeFrameInfo ()
 
void OutputOneSplit (int32 seg_begin, int32 seg_end)
 
void DoSplit (SplitExampleStats *stats)
 
void DoExcise (SplitExampleStats *stats)
 
int32 NumFrames () const
 
int32 RightContext ()
 
void CreateOutputLattice (int32 seg_begin, int32 seg_end, CompactLattice *clat_out)
 
StateId GetOutputStateId (StateId s, unordered_map< StateId, StateId > *state_map, Lattice *lat_out)
 

Static Private Member Functions

static void RemoveAllOutputSymbols (Lattice *lat)
 

Private Attributes

const SplitDiscriminativeExampleConfig & config_
 
const TransitionModel & tmodel_
 
const DiscriminativeNnetExample & eg_
 
std::vector< DiscriminativeNnetExample > * egs_out_
 
Lattice lat_
 
std::vector< FrameInfo > frame_info_
 
std::vector< int32 > state_times_
 

Detailed Description

For each frame, judge:

  • does it produce a nonzero derivative? [this differs between MMI and MPE]
  • can it be split here? [or what is the penalty for splitting here?]
    • depends on whether the lattice has just one path at that point.

Time taken to process a segment of a given length: [must be sub-linear.] [Use a quadratic function that is maximal at the specified segment length and zero at zero length.]
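As a minimal sketch of the processing-time model described above, assuming segment lengths in frames that never exceed the target length, one quadratic with the stated properties is length * (2 * target - length) / target. The function name ProcessingCost is hypothetical; this cost model does not appear in the DoSplit() listing, which simply splits at every pinch point.

```cpp
#include <cassert>

// Hypothetical cost model (not part of Kaldi): time to process a segment of
// `length` frames when the preferred segment length is `target_length`.
// The concave quadratic is 0 at length 0 and peaks at length == target, so
// the per-frame cost falls as a segment grows toward the target.
double ProcessingCost(double length, double target_length) {
  assert(target_length > 0.0);
  // Lengths past the peak are not modeled by this sketch.
  assert(length >= 0.0 && length <= target_length);
  return length * (2.0 * target_length - length) / target_length;
}
```

With a target of 100 frames, one 100-frame segment costs 100 while two 50-frame segments cost 75 each (150 in total), so under this model splitting only pays off when it lets unneeded frames be dropped.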

No penalty for processing frames we don't need to process (this is already implicit in the segment-processing time above).

Penalty for splitting where we should not split. [Make it proportional to log(#paths).]
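The note does not say over which region the paths are counted. Assuming they are counted between two candidate split points of the topologically sorted lattice, the hypothetical helper LogNumPaths below computes log(#paths) in log space, so long lattices cannot overflow; a penalty could then be a tunable weight times this value. A region consisting of a single path contributes zero, consistent with the "just one path at that point" condition above.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

// Hypothetical helper (not a Kaldi function): natural log of the number of
// distinct paths from `start` to `end` in an acyclic graph whose state ids
// are in topological order (every arc goes to a higher id), as after
// TopSort(). arcs_out[s] lists the destination of each arc leaving s;
// parallel arcs count as distinct paths.
double LogNumPaths(const std::vector<std::vector<int>> &arcs_out,
                   int start, int end) {
  const double kLogZero = -std::numeric_limits<double>::infinity();
  assert(start >= 0 && start <= end &&
         end < static_cast<int>(arcs_out.size()));
  std::vector<double> log_count(arcs_out.size(), kLogZero);
  log_count[start] = 0.0;  // exactly one (empty) path reaches `start`
  for (int s = start; s < end; ++s) {
    if (log_count[s] == kLogZero) continue;  // not reachable from start
    for (int next : arcs_out[s]) {
      assert(next > s);  // topological order
      double a = log_count[next], b = log_count[s];
      if (a == kLogZero) {
        log_count[next] = b;
      } else {  // log(exp(a) + exp(b)), computed stably
        double hi = std::max(a, b), lo = std::min(a, b);
        log_count[next] = hi + std::log1p(std::exp(lo - hi));
      }
    }
  }
  return log_count[end];  // -inf if `end` is unreachable
}
```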

Definition at line 100 of file nnet-example-functions.cc.

Member Typedef Documentation

typedef LatticeArc Arc
private

Definition at line 134 of file nnet-example-functions.cc.

typedef Arc::Label Label
private

Definition at line 136 of file nnet-example-functions.cc.

typedef Arc::StateId StateId
private

Definition at line 135 of file nnet-example-functions.cc.

Constructor & Destructor Documentation

DiscriminativeExampleSplitter ( const SplitDiscriminativeExampleConfig & config,
const TransitionModel & tmodel,
const DiscriminativeNnetExample & eg,
std::vector< DiscriminativeNnetExample > *  egs_out 
)
inline

Definition at line 102 of file nnet-example-functions.cc.

    :
    config_(config), tmodel_(tmodel), eg_(eg), egs_out_(egs_out) { }

Member Function Documentation

void CollapseTransitionIds ( )
private

Definition at line 236 of file nnet-example-functions.cc.

References KALDI_ASSERT, DiscriminativeExampleSplitter::lat_, kaldi::LatticeStateTimes(), DiscriminativeExampleSplitter::tmodel_, and TransitionModel::TransitionIdToPdf().

Referenced by DiscriminativeExampleSplitter::PrepareLattice().

{
  std::vector<int32> times;
  TopSort(&lat_); // Topologically sort the lattice (required by
                  // LatticeStateTimes)
  int32 num_frames = LatticeStateTimes(lat_, &times);
  StateId num_states = lat_.NumStates();

  std::vector<std::map<int32, int32> > pdf_to_tid(num_frames);
  for (StateId s = 0; s < num_states; s++) {
    int32 t = times[s];
    for (fst::MutableArcIterator<Lattice> aiter(&lat_, s);
         !aiter.Done(); aiter.Next()) {
      KALDI_ASSERT(t >= 0 && t < num_frames);
      Arc arc = aiter.Value();
      KALDI_ASSERT(arc.ilabel != 0 && arc.ilabel == arc.olabel);
      int32 pdf = tmodel_.TransitionIdToPdf(arc.ilabel);
      if (pdf_to_tid[t].count(pdf) != 0) {
        arc.ilabel = arc.olabel = pdf_to_tid[t][pdf];
        aiter.SetValue(arc);
      } else {
        pdf_to_tid[t][pdf] = arc.ilabel;
      }
    }
  }
}
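CollapseTransitionIds() keeps, for each frame, the first transition-id seen for each pdf and relabels later arcs on that frame with it. The sketch below restates that per-frame map on a flat arc list, assuming a plain lookup table in place of TransitionModel::TransitionIdToPdf(); SimpleArc, CollapseTids and tid_to_pdf are hypothetical names. As in the listing, which transition-id survives depends on visiting order (state order after TopSort() in Kaldi, vector order here).

```cpp
#include <cassert>
#include <map>
#include <vector>

// Hypothetical flattened arc: the frame it is on and its transition-id
// (in the listing, ilabel == olabel holds the transition-id).
struct SimpleArc {
  int frame;
  int tid;
};

// For each frame, remember the first transition-id seen for each pdf and
// relabel later arcs on that frame that map to the same pdf.
// tid_to_pdf stands in for TransitionModel::TransitionIdToPdf().
void CollapseTids(const std::vector<int> &tid_to_pdf, int num_frames,
                  std::vector<SimpleArc> *arcs) {
  std::vector<std::map<int, int>> pdf_to_tid(num_frames);
  for (SimpleArc &arc : *arcs) {
    assert(arc.frame >= 0 && arc.frame < num_frames);
    assert(arc.tid > 0 && arc.tid < static_cast<int>(tid_to_pdf.size()));
    int pdf = tid_to_pdf[arc.tid];
    auto it = pdf_to_tid[arc.frame].find(pdf);
    if (it != pdf_to_tid[arc.frame].end())
      arc.tid = it->second;  // reuse the representative transition-id
    else
      pdf_to_tid[arc.frame][pdf] = arc.tid;  // first one seen for this pdf
  }
}
```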
bool ComputeFrameInfo ( )
private

Definition at line 299 of file nnet-example-functions.cc.

References DiscriminativeExampleSplitter::FrameInfo::can_excise_frame, DiscriminativeExampleSplitter::config_, SplitDiscriminativeExampleConfig::criterion, DiscriminativeExampleSplitter::FrameInfo::den_pdf_count, DiscriminativeExampleSplitter::FrameInfo::den_state_count, SplitDiscriminativeExampleConfig::drop_frames, DiscriminativeExampleSplitter::eg_, DiscriminativeExampleSplitter::frame_info_, KALDI_ASSERT, DiscriminativeExampleSplitter::lat_, kaldi::LatticeStateTimes(), DiscriminativeExampleSplitter::FrameInfo::multiple_transition_ids, DiscriminativeExampleSplitter::FrameInfo::nonzero_derivative, DiscriminativeNnetExample::num_ali, DiscriminativeExampleSplitter::FrameInfo::num_den_overlap, DiscriminativeExampleSplitter::NumFrames(), DiscriminativeExampleSplitter::state_times_, DiscriminativeExampleSplitter::tmodel_, and TransitionModel::TransitionIdToPdf().

Referenced by DiscriminativeExampleSplitter::Excise(), and DiscriminativeExampleSplitter::Split().

{

  int32 num_frames = NumFrames();

  frame_info_.clear();
  frame_info_.resize(num_frames + 1);

  LatticeStateTimes(lat_, &state_times_);

  std::vector<std::set<int32> > pdfs_per_frame(num_frames),
      tids_per_frame(num_frames);

  int32 num_states = lat_.NumStates();

  for (int32 state = 0; state < num_states; state++) {
    int32 t = state_times_[state];
    KALDI_ASSERT(t >= 0 && t <= num_frames);
    frame_info_[t].den_state_count++;
    for (fst::ArcIterator<Lattice> aiter(lat_, state); !aiter.Done();
         aiter.Next()) {
      const LatticeArc &arc = aiter.Value();
      KALDI_ASSERT(arc.ilabel != 0 && arc.ilabel == arc.olabel);
      int32 transition_id = arc.ilabel,
          pdf_id = tmodel_.TransitionIdToPdf(transition_id);
      tids_per_frame[t].insert(transition_id);
      pdfs_per_frame[t].insert(pdf_id);
    }
    if (t < num_frames)
      frame_info_[t+1].start_state = std::min(state,
                                              frame_info_[t+1].start_state);
    frame_info_[t].end_state = std::max(state,
                                        frame_info_[t].end_state);
  }

  for (int32 i = 1; i <= NumFrames(); i++)
    frame_info_[i].end_state = std::max(frame_info_[i-1].end_state,
                                        frame_info_[i].end_state);
  for (int32 i = NumFrames() - 1; i >= 0; i--)
    frame_info_[i].start_state = std::min(frame_info_[i+1].start_state,
                                          frame_info_[i].start_state);

  for (int32 t = 0; t < num_frames; t++) {
    FrameInfo &frame_info = frame_info_[t];
    int32 transition_id = eg_.num_ali[t],
        pdf_id = tmodel_.TransitionIdToPdf(transition_id);
    frame_info.num_den_overlap = (pdfs_per_frame[t].count(pdf_id) != 0);
    frame_info.multiple_transition_ids = (tids_per_frame[t].size() > 1);
    KALDI_ASSERT(!pdfs_per_frame[t].empty());
    frame_info.den_pdf_count = pdfs_per_frame[t].size();

    if (config_.criterion == "mpfe" || config_.criterion == "smbr") {
      frame_info.nonzero_derivative = (frame_info.den_pdf_count > 1);
    } else {
      KALDI_ASSERT(config_.criterion == "mmi");
      if (config_.drop_frames) {
        // With frame dropping, we'll get nonzero derivative only
        // if num and den overlap, *and* den has >1 active pdf.
        frame_info.nonzero_derivative = frame_info.num_den_overlap &&
            frame_info.den_state_count > 1;
      } else {
        // Without frame dropping, we'll get nonzero derivative if num and den
        // do not overlap, or den has >1 active pdf.
        frame_info.nonzero_derivative = !frame_info.num_den_overlap ||
            frame_info.den_state_count > 1;
      }
    }
    // If a frame is part of a segment, but it's not going to contribute
    // to the derivative and the den lattice has only one pdf active
    // at that time, then this frame can be excised from the lattice
    // because it will not affect the posteriors around it.
    if (config_.criterion == "mpfe") {
      frame_info.can_excise_frame =
          !frame_info.nonzero_derivative &&
          !frame_info.multiple_transition_ids;
      // in the mpfe case, if there are multiple transition-ids on a
      // frame there may be multiple phones on a frame, which could
      // contribute to the objective function even if they share pdf-ids.
      // (this was an issue that came up during testing).
    } else {
      frame_info.can_excise_frame =
          !frame_info.nonzero_derivative && frame_info.den_pdf_count == 1;
    }
  }
  return true;
}
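The per-frame decisions at the end of ComputeFrameInfo() amount to a small decision table. The sketch below restates them as a pure function over a hypothetical FrameStats struct (a subset of FrameInfo), so the criterion-specific cases can be checked in isolation. It follows the code rather than the comments: in both MMI branches the listing tests den_state_count > 1, although the comments speak of more than one active pdf.

```cpp
#include <cassert>
#include <string>

// Hypothetical subset of FrameInfo: the per-frame statistics that
// ComputeFrameInfo() gathers before deciding.
struct FrameStats {
  bool num_den_overlap;          // numerator pdf appears in the den lattice
  bool multiple_transition_ids;  // >1 distinct transition-id on this frame
  int den_pdf_count;             // distinct pdfs on this frame
  int den_state_count;           // lattice states at this frame
};

struct FrameDecision {
  bool nonzero_derivative;
  bool can_excise_frame;
};

// Restates the decisions made per frame in the listing above.
FrameDecision DecideFrame(const std::string &criterion, bool drop_frames,
                          const FrameStats &f) {
  FrameDecision d;
  if (criterion == "mpfe" || criterion == "smbr") {
    d.nonzero_derivative = (f.den_pdf_count > 1);
  } else {
    assert(criterion == "mmi");
    d.nonzero_derivative =
        drop_frames ? (f.num_den_overlap && f.den_state_count > 1)
                    : (!f.num_den_overlap || f.den_state_count > 1);
  }
  if (criterion == "mpfe")  // several tids may mean several phones
    d.can_excise_frame = !d.nonzero_derivative && !f.multiple_transition_ids;
  else
    d.can_excise_frame = !d.nonzero_derivative && f.den_pdf_count == 1;
  return d;
}
```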
void CreateOutputLattice ( int32  seg_begin,
int32  seg_end,
CompactLattice * clat_out 
)
private

Definition at line 644 of file nnet-example-functions.cc.

References fst::ConvertLattice(), DiscriminativeExampleSplitter::frame_info_, DiscriminativeExampleSplitter::GetOutputStateId(), KALDI_ASSERT, DiscriminativeExampleSplitter::lat_, DiscriminativeExampleSplitter::NumFrames(), LatticeWeightTpl< BaseFloat >::One(), DiscriminativeExampleSplitter::RemoveAllOutputSymbols(), and DiscriminativeExampleSplitter::state_times_.

Referenced by DiscriminativeExampleSplitter::OutputOneSplit().

{
  Lattice lat_out;

  // Below, state_map will map from states in the original lattice
  // lat_ to ones in the new lattice lat_out.
  unordered_map<StateId, StateId> state_map;

  // The range of the loop over s could be made over the
  // entire lattice, but we limit it for efficiency.

  for (StateId s = frame_info_[seg_begin].start_state;
       s <= frame_info_[seg_end].end_state; s++) {
    int32 t = state_times_[s];

    if (t < seg_begin || t > seg_end) // state out of range.
      continue;

    int32 this_state = GetOutputStateId(s, &state_map, &lat_out);

    if (t == seg_begin) // note: we only split on frames with just one
      lat_out.SetStart(this_state); // state, so we reach this only once.

    if (t == seg_end) { // Make it final and don't process its arcs out.
      if (seg_end == NumFrames()) {
        lat_out.SetFinal(this_state, lat_.Final(s));
      } else {
        lat_out.SetFinal(this_state, LatticeWeight::One());
      }
      continue; // don't process arcs out of this state.
    }

    for (fst::ArcIterator<Lattice> aiter(lat_, s); !aiter.Done(); aiter.Next()) {
      const Arc &arc = aiter.Value();
      StateId next_state = GetOutputStateId(arc.nextstate,
                                            &state_map, &lat_out);
      KALDI_ASSERT(arc.ilabel != 0 && arc.ilabel == arc.olabel); // We expect no epsilons.
      lat_out.AddArc(this_state, Arc(arc.ilabel, arc.olabel, arc.weight,
                                     next_state));
    }
  }
  Connect(&lat_out); // this is not really necessary, it's only to make sure
                     // the assert below fails when it should. TODO: remove it.
  KALDI_ASSERT(lat_out.NumStates() > 0);
  RemoveAllOutputSymbols(&lat_out);
  ConvertLattice(lat_out, clat_out);
}
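CreateOutputLattice() only scans state ids from frame_info_[seg_begin].start_state to frame_info_[seg_end].end_state. The sketch below restates the passes in ComputeFrameInfo() that produce those bounds, assuming FrameInfo initializes start_state to a large value and end_state to -1 (its definition is not shown on this page); ComputeStateRanges is a hypothetical name. After the running max and min, end_state[t] is the largest id of any state at frame t or earlier, and start_state[t] is no larger than the id of any state at frame t or later (given topological order), so every state of the segment falls inside the scanned range and the frame check in the loop discards the rest.

```cpp
#include <algorithm>
#include <cassert>
#include <limits>
#include <vector>

// Hypothetical helper mirroring the start_state/end_state passes of
// ComputeFrameInfo(). state_times[s] is the frame of state s (as produced
// by LatticeStateTimes()); final states sit at time num_frames.
void ComputeStateRanges(const std::vector<int> &state_times, int num_frames,
                        std::vector<int> *start_state,
                        std::vector<int> *end_state) {
  // Assumed initial values meaning "no state seen yet" for each frame.
  start_state->assign(num_frames + 1, std::numeric_limits<int>::max());
  end_state->assign(num_frames + 1, -1);
  for (int s = 0; s < static_cast<int>(state_times.size()); ++s) {
    int t = state_times[s];
    assert(t >= 0 && t <= num_frames);
    if (t < num_frames)
      (*start_state)[t + 1] = std::min(s, (*start_state)[t + 1]);
    (*end_state)[t] = std::max(s, (*end_state)[t]);
  }
  // Running max forward over end_state, running min backward over start_state.
  for (int i = 1; i <= num_frames; ++i)
    (*end_state)[i] = std::max((*end_state)[i - 1], (*end_state)[i]);
  for (int i = num_frames - 1; i >= 0; --i)
    (*start_state)[i] = std::min((*start_state)[i + 1], (*start_state)[i]);
}
```

For state times {0, 1, 1, 2, 3}, a segment [1, 2] is scanned over ids 0 to 3, and the frame check then keeps states 1, 2 and 3.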
void DoExcise ( SplitExampleStats * stats)
private

Definition at line 398 of file nnet-example-functions.cc.

References DiscriminativeNnetExample::Check(), fst::ConvertLattice(), VectorBase< Real >::CopyFromVec(), DiscriminativeNnetExample::den_lat, DiscriminativeExampleSplitter::eg_, DiscriminativeExampleSplitter::egs_out_, DiscriminativeExampleSplitter::frame_info_, DiscriminativeNnetExample::input_frames, KALDI_ASSERT, KALDI_WARN, DiscriminativeExampleSplitter::lat_, DiscriminativeNnetExample::left_context, SplitExampleStats::longest_segment_after_excise, DiscriminativeNnetExample::num_ali, SplitExampleStats::num_frames_kept_after_excise, MatrixBase< Real >::NumCols(), DiscriminativeExampleSplitter::NumFrames(), DiscriminativeExampleSplitter::RemoveAllOutputSymbols(), Matrix< Real >::Resize(), DiscriminativeExampleSplitter::RightContext(), DiscriminativeNnetExample::spk_info, DiscriminativeExampleSplitter::state_times_, and DiscriminativeNnetExample::weight.

Referenced by DiscriminativeExampleSplitter::Excise().

{
  int32 left_context = eg_.left_context,
      right_context = RightContext(),
      num_frames = NumFrames();
  // Compute, for each frame, whether we can excise it.
  //
  std::vector<bool> can_excise(num_frames, false);

  bool need_some_frame = false;
  for (int32 t = 0; t < num_frames; t++) {
    can_excise[t] = frame_info_[t].can_excise_frame;
    if (!can_excise[t])
      need_some_frame = true;
  }
  if (!need_some_frame) { // We don't need any frame within this file, so simply
                          // delete the segment.
    KALDI_WARN << "Example completely removed when excising."; // unexpected,
    // as the segment should have been deleted when splitting.
    egs_out_->clear();
    return;
  }
  egs_out_->resize(1);
  DiscriminativeNnetExample &eg_out = (*egs_out_)[0];

  // start_t and end_t will be the central part of the segment, excluding any
  // frames at the edges that we can excise.
  int32 start_t, end_t;
  for (start_t = 0; can_excise[start_t]; start_t++);
  for (end_t = num_frames; can_excise[end_t-1]; end_t--);

  // for frames from start_t to end_t-1, do not excise them if
  // they are within the context-window of a frame that we need to keep.
  // Note: we do t2 = t - right_context to t + left_context, because we're
  // concerned whether frame t2 has frame t in its window... it might
  // seem a bit backwards.
  std::vector<bool> will_excise(can_excise);
  for (int32 t = start_t; t < end_t; t++) {
    for (int32 t2 = t - right_context; t2 <= t + left_context; t2++)
      if (t2 >= start_t && t2 < end_t && !can_excise[t2])
        will_excise[t] = false; // can't excise this frame, it's needed for
                                // context.
  }

  // Remove all un-needed frames from the lattice by replacing the
  // symbols with epsilon and then removing the epsilons.
  // Note, this operation is destructive (it changes lat_).
  int32 num_states = lat_.NumStates();
  for (int32 state = 0; state < num_states; state++) {
    int32 t = state_times_[state];
    for (::fst::MutableArcIterator<Lattice> aiter(&lat_, state); !aiter.Done();
         aiter.Next()) {
      Arc arc = aiter.Value();
      if (will_excise[t]) {
        arc.ilabel = arc.olabel = 0;
        aiter.SetValue(arc);
      }
    }
  }
  RmEpsilon(&lat_);
  RemoveAllOutputSymbols(&lat_);
  ConvertLattice(lat_, &eg_out.den_lat);

  eg_out.num_ali.clear();
  int32 num_frames_kept = 0;
  for (int32 t = 0; t < num_frames; t++) {
    if (!will_excise[t]) {
      eg_out.num_ali.push_back(eg_.num_ali[t]);
      num_frames_kept++;
    }
  }

  stats->num_frames_kept_after_excise += num_frames_kept;
  stats->longest_segment_after_excise = std::max(stats->longest_segment_after_excise,
                                                 num_frames_kept);

  int32 num_frames_kept_plus = num_frames_kept + left_context + right_context;
  eg_out.input_frames.Resize(num_frames_kept_plus,
                             eg_.input_frames.NumCols());

  // the left-context of the output will be shifted to the right by
  // start_t.
  for (int32 i = 0; i < left_context; i++) {
    SubVector<BaseFloat> dst(eg_out.input_frames, i);
    SubVector<BaseFloat> src(eg_.input_frames, start_t + i);
    dst.CopyFromVec(src);
  }
  // the right-context will also be shifted, we take the frames
  // to the right of end_t.
  for (int32 i = 0; i < right_context; i++) {
    SubVector<BaseFloat> dst(eg_out.input_frames,
                             num_frames_kept + left_context + i);
    SubVector<BaseFloat> src(eg_.input_frames,
                             end_t + left_context + i);
    dst.CopyFromVec(src);
  }
  // now copy the central frames (those that were not excised).
  int32 dst_t = 0;
  for (int32 t = start_t; t < end_t; t++) {
    if (!will_excise[t]) {
      SubVector<BaseFloat> dst(eg_out.input_frames,
                               left_context + dst_t);
      SubVector<BaseFloat> src(eg_.input_frames,
                               left_context + t);
      dst.CopyFromVec(src);
      dst_t++;
    }
  }
  KALDI_ASSERT(dst_t == num_frames_kept);


  eg_out.weight = eg_.weight;
  eg_out.left_context = eg_.left_context;
  eg_out.spk_info = eg_.spk_info;

  eg_out.Check();
}
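The context-protection step and the row bookkeeping in DoExcise() can be checked without a lattice. The hypothetical PlanExcise below mirrors the computation of start_t, end_t and will_excise, then lists the rows of the original input_frames that the three copy loops read, in output order (left context, surviving central frames, right context). It assumes at least one frame cannot be excised, which DoExcise() guarantees by returning early otherwise.

```cpp
#include <cassert>
#include <vector>

// Hypothetical result type: which label frames are dropped, and which rows
// of the original input_frames end up in the output (in output order).
struct ExcisePlan {
  std::vector<bool> will_excise;
  std::vector<int> input_rows;
};

ExcisePlan PlanExcise(const std::vector<bool> &can_excise, int left_context,
                      int right_context) {
  assert(left_context >= 0 && right_context >= 0);
  int num_frames = static_cast<int>(can_excise.size());
  // Central region [start_t, end_t): trim excisable frames at both edges.
  int start_t = 0, end_t = num_frames;
  while (start_t < num_frames && can_excise[start_t]) ++start_t;
  assert(start_t < num_frames);  // some frame must be kept
  while (end_t > start_t && can_excise[end_t - 1]) --end_t;
  ExcisePlan plan;
  plan.will_excise = can_excise;
  // Keep any central frame inside the window of a frame that must be kept;
  // frame t2's window is [t2 - left_context, t2 + right_context].
  for (int t = start_t; t < end_t; ++t)
    for (int t2 = t - right_context; t2 <= t + left_context; ++t2)
      if (t2 >= start_t && t2 < end_t && !can_excise[t2])
        plan.will_excise[t] = false;
  // Row r of the input holds label frame r - left_context.
  for (int i = 0; i < left_context; ++i)  // left context of frame start_t
    plan.input_rows.push_back(start_t + i);
  for (int t = start_t; t < end_t; ++t)  // surviving central frames
    if (!plan.will_excise[t]) plan.input_rows.push_back(left_context + t);
  for (int i = 0; i < right_context; ++i)  // right context after end_t - 1
    plan.input_rows.push_back(end_t + left_context + i);
  return plan;
}
```

With can_excise = {1, 1, 0, 1, 1, 1, 0, 1, 1}, left_context = 2 and right_context = 0, frame 3 is dropped (frame 2 needs no right context) while frames 4 and 5 survive as left context of frame 6.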
void DoSplit ( SplitExampleStats * stats)
private

Definition at line 516 of file nnet-example-functions.cc.

References DiscriminativeExampleSplitter::egs_out_, DiscriminativeExampleSplitter::frame_info_, SplitExampleStats::longest_lattice, SplitExampleStats::longest_segment_after_split, SplitExampleStats::num_frames_kept_after_split, SplitExampleStats::num_frames_must_keep, SplitExampleStats::num_frames_orig, SplitExampleStats::num_kept_segments, SplitExampleStats::num_lattices, SplitExampleStats::num_segments, DiscriminativeExampleSplitter::NumFrames(), and DiscriminativeExampleSplitter::OutputOneSplit().

Referenced by DiscriminativeExampleSplitter::Split().

{
  std::vector<int32> split_points;
  int32 num_frames = NumFrames();
  {
    // Make the "split points" 0 and num_frames, and
    // any frame that has just one state on it and the previous
    // frame had >1 state. This gives us one split for each
    // "pinch point" in the lattice. Later we may move each split
    // to a more optimal location.
    split_points.push_back(0);
    for (int32 t = 1; t < num_frames; t++) {
      if (frame_info_[t].den_state_count == 1 &&
          frame_info_[t-1].den_state_count > 1)
        split_points.push_back(t);
    }
    split_points.push_back(num_frames);
  }

  std::vector<bool> is_kept(split_points.size() - 1);
  { // A "split" is a pair of successive split points. Work out for each split
    // whether we must keep it (we must if it contains at least one frame for
    // which "nonzero_derivative" == true.)
    for (size_t s = 0; s < is_kept.size(); s++) {
      int32 start = split_points[s], end = split_points[s+1];
      bool keep_this_split = false;
      for (int32 t = start; t < end; t++)
        if (frame_info_[t].nonzero_derivative)
          keep_this_split = true;
      is_kept[s] = keep_this_split;
    }
  }

  egs_out_->clear();
  egs_out_->reserve(is_kept.size());

  stats->num_lattices++;
  stats->longest_lattice = std::max(stats->longest_lattice, num_frames);
  stats->num_segments += is_kept.size();
  stats->num_frames_orig += num_frames;
  for (int32 t = 0; t < num_frames; t++)
    if (frame_info_[t].nonzero_derivative)
      stats->num_frames_must_keep++;

  for (size_t s = 0; s < is_kept.size(); s++) {
    if (is_kept[s]) {
      stats->num_kept_segments++;
      OutputOneSplit(split_points[s], split_points[s+1]);
      int32 segment_len = split_points[s+1] - split_points[s];
      stats->num_frames_kept_after_split += segment_len;
      stats->longest_segment_after_split =
          std::max(stats->longest_segment_after_split, segment_len);
    }
  }
}
void Excise ( SplitExampleStats * stats)
inline

Definition at line 109 of file nnet-example-functions.cc.

References DiscriminativeNnetExample::Check(), DiscriminativeExampleSplitter::ComputeFrameInfo(), DiscriminativeExampleSplitter::config_, DiscriminativeExampleSplitter::DoExcise(), DiscriminativeExampleSplitter::eg_, DiscriminativeExampleSplitter::egs_out_, SplitDiscriminativeExampleConfig::excise, and DiscriminativeExampleSplitter::PrepareLattice().

Referenced by kaldi::nnet2::ExciseDiscriminativeExample().

{
  eg_.Check();
  PrepareLattice(false);
  ComputeFrameInfo();
  if (!config_.excise) {
    egs_out_->resize(1);
    (*egs_out_)[0] = eg_;
  } else {
    DoExcise(stats);
  }
}
DiscriminativeExampleSplitter::StateId GetOutputStateId ( StateId  s,
unordered_map< StateId, StateId > *  state_map,
Lattice * lat_out 
)
private

Definition at line 635 of file nnet-example-functions.cc.

Referenced by DiscriminativeExampleSplitter::CreateOutputLattice().

{
  if (state_map->count(s) == 0) {
    return ((*state_map)[s] = lat_out->AddState());
  } else {
    return (*state_map)[s];
  }
}
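GetOutputStateId() is a memoized mapping: the first request for an input state creates an output state, and later requests return the same id. The sketch below shows that contract with a hypothetical FakeOutputLattice standing in for the output Lattice; MapToOutputState is also a hypothetical name. It uses find() and emplace(), which saves the second hash lookup that count() followed by operator[] performs when the state is already mapped.

```cpp
#include <cassert>
#include <unordered_map>

// Hypothetical stand-in for the output Lattice: AddState() hands out
// consecutive state ids starting at 0.
struct FakeOutputLattice {
  int num_states = 0;
  int AddState() { return num_states++; }
};

// Return the output state for input state s, creating it on first use.
int MapToOutputState(int s, std::unordered_map<int, int> *state_map,
                     FakeOutputLattice *lat_out) {
  assert(s >= 0);
  auto it = state_map->find(s);
  if (it != state_map->end()) return it->second;  // already mapped
  int new_state = lat_out->AddState();
  state_map->emplace(s, new_state);
  return new_state;
}
```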
void OutputOneSplit ( int32  seg_begin,
int32  seg_end 
)
private

Definition at line 596 of file nnet-example-functions.cc.

References DiscriminativeNnetExample::Check(), DiscriminativeExampleSplitter::CreateOutputLattice(), DiscriminativeNnetExample::den_lat, DiscriminativeExampleSplitter::eg_, DiscriminativeExampleSplitter::egs_out_, DiscriminativeNnetExample::input_frames, KALDI_ASSERT, DiscriminativeNnetExample::left_context, DiscriminativeNnetExample::num_ali, MatrixBase< Real >::NumCols(), DiscriminativeExampleSplitter::NumFrames(), MatrixBase< Real >::Range(), DiscriminativeExampleSplitter::RightContext(), DiscriminativeNnetExample::spk_info, and DiscriminativeNnetExample::weight.

Referenced by DiscriminativeExampleSplitter::DoSplit().

{
  KALDI_ASSERT(seg_begin >= 0 && seg_end > seg_begin && seg_end <= NumFrames());
  egs_out_->resize(egs_out_->size() + 1);
  int32 left_context = eg_.left_context, right_context = RightContext(),
      tot_context = left_context + right_context;
  DiscriminativeNnetExample &eg_out = egs_out_->back();
  eg_out.weight = eg_.weight;

  eg_out.num_ali.insert(eg_out.num_ali.end(),
                        eg_.num_ali.begin() + seg_begin,
                        eg_.num_ali.begin() + seg_end);

  CreateOutputLattice(seg_begin, seg_end, &(eg_out.den_lat));

  eg_out.input_frames = eg_.input_frames.Range(seg_begin, seg_end - seg_begin +
                                               tot_context,
                                               0, eg_.input_frames.NumCols());

  eg_out.left_context = eg_.left_context;

  eg_out.spk_info = eg_.spk_info;

  eg_out.Check();
}
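The row arithmetic in RightContext() and OutputOneSplit() follows from the layout of input_frames: row r holds the features for label frame r - left_context. The hypothetical helpers below restate it, which makes it easy to check that a segment's Range() call covers exactly its frames plus their context.

```cpp
#include <cassert>

// Hypothetical restatement of RightContext(): whatever rows are left over
// after the label frames and the left context.
int RightContextOf(int num_input_rows, int num_frames, int left_context) {
  int right_context = num_input_rows - num_frames - left_context;
  assert(right_context >= 0);
  return right_context;
}

struct RowRange {
  int first_row;  // first row of input_frames to copy
  int num_rows;   // number of rows to copy
};

// Rows covering frames [seg_begin, seg_end) plus their context. Frame
// seg_begin sits at row seg_begin + left_context, so the copy starts at
// row seg_begin, as in the Range() call in OutputOneSplit().
RowRange SegmentRows(int seg_begin, int seg_end, int left_context,
                     int right_context) {
  assert(seg_begin >= 0 && seg_end > seg_begin);
  return RowRange{seg_begin,
                  seg_end - seg_begin + left_context + right_context};
}
```

With 110 input rows, 100 frames and 6 frames of left context, the right context is 4; segment [10, 20) then copies rows 10 to 29, where rows 16 to 25 hold frames 10 to 19.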
void PrepareLattice ( bool  first_time)
private

Definition at line 263 of file nnet-example-functions.cc.

References SplitDiscriminativeExampleConfig::collapse_transition_ids, DiscriminativeExampleSplitter::CollapseTransitionIds(), DiscriminativeExampleSplitter::config_, fst::ConvertLattice(), SplitDiscriminativeExampleConfig::criterion, DiscriminativeNnetExample::den_lat, SplitDiscriminativeExampleConfig::determinize, DiscriminativeExampleSplitter::eg_, DiscriminativeExampleSplitter::lat_, and SplitDiscriminativeExampleConfig::minimize.

Referenced by DiscriminativeExampleSplitter::Excise(), and DiscriminativeExampleSplitter::Split().

{
  ConvertLattice(eg_.den_lat, &lat_);

  Project(&lat_, fst::PROJECT_INPUT); // Get rid of the word labels and put the
                                      // transition-ids on both sides.

  RmEpsilon(&lat_); // Remove epsilons; this simplifies
                    // certain things.

  if (first_time) {
    if (config_.collapse_transition_ids && config_.criterion != "mpfe")
      CollapseTransitionIds();

    if (config_.determinize) {
      if (!config_.minimize) {
        Lattice det_lat;
        Determinize(lat_, &det_lat);
        lat_ = det_lat;
      } else {
        Lattice tmp_lat;
        Reverse(lat_, &tmp_lat);
        Determinize(tmp_lat, &lat_);
        Reverse(lat_, &tmp_lat);
        Determinize(tmp_lat, &lat_);
        RmEpsilon(&lat_);
        // Previously we determinized, then did
        // Minimize(&lat_);
        // but this was too slow.
      }
    }
  }
  TopSort(&lat_); // Topologically sort the lattice.
}
void RemoveAllOutputSymbols ( Lattice * lat)
static private

Definition at line 623 of file nnet-example-functions.cc.

Referenced by DiscriminativeExampleSplitter::CreateOutputLattice(), and DiscriminativeExampleSplitter::DoExcise().

{
  for (StateId s = 0; s < lat->NumStates(); s++) {
    for (::fst::MutableArcIterator<Lattice> aiter(lat, s); !aiter.Done();
         aiter.Next()) {
      Arc arc = aiter.Value();
      arc.olabel = 0;
      aiter.SetValue(arc);
    }
  }
}
int32 RightContext ( )
inline private

Definition at line 161 of file nnet-example-functions.cc.

References DiscriminativeExampleSplitter::eg_, DiscriminativeNnetExample::input_frames, DiscriminativeNnetExample::left_context, DiscriminativeExampleSplitter::NumFrames(), and MatrixBase< Real >::NumRows().

Referenced by DiscriminativeExampleSplitter::DoExcise(), and DiscriminativeExampleSplitter::OutputOneSplit().

{ return eg_.input_frames.NumRows() - NumFrames() - eg_.left_context; }
void Split ( SplitExampleStats * stats)
inline

Definition at line 121 of file nnet-example-functions.cc.

References DiscriminativeNnetExample::Check(), DiscriminativeExampleSplitter::ComputeFrameInfo(), DiscriminativeExampleSplitter::config_, DiscriminativeExampleSplitter::DoSplit(), DiscriminativeExampleSplitter::eg_, DiscriminativeExampleSplitter::egs_out_, DiscriminativeExampleSplitter::PrepareLattice(), and SplitDiscriminativeExampleConfig::split.

Referenced by kaldi::nnet2::SplitDiscriminativeExample().

{
  if (!config_.split) {
    egs_out_->resize(1);
    (*egs_out_)[0] = eg_;
  } else {
    eg_.Check();
    PrepareLattice(true);
    ComputeFrameInfo();
    DoSplit(stats);
  }
}

Member Data Documentation


The documentation for this class was generated from the following file: