DiscriminativeSupervisionSplitter Class Reference

#include <discriminative-supervision.h>

Collaboration diagram for DiscriminativeSupervisionSplitter:

Classes

struct  LatticeInfo
 

Public Types

typedef fst::ArcTpl< LatticeWeight > LatticeArc
 
typedef fst::VectorFst< LatticeArc > Lattice
 

Public Member Functions

 DiscriminativeSupervisionSplitter (const SplitDiscriminativeSupervisionOptions &config, const TransitionModel &tmodel, const DiscriminativeSupervision &supervision)
 
void GetFrameRange (int32 begin_frame, int32 frames_per_sequence, bool normalize, DiscriminativeSupervision *supervision) const
 
const Lattice & DenLat () const
 

Private Member Functions

void CreateRangeLattice (const Lattice &in_lat, const LatticeInfo &scores, int32 begin_frame, int32 end_frame, bool normalize, Lattice *out_lat) const
 
void ComputeLatticeScores (const Lattice &lat, LatticeInfo *scores) const
 
void PrepareLattice (Lattice *lat, LatticeInfo *scores) const
 
void CollapseTransitionIds (const std::vector< int32 > &state_times, Lattice *lat) const
 

Private Attributes

const SplitDiscriminativeSupervisionOptions & config_
 
const TransitionModel & tmodel_
 
const DiscriminativeSupervision & supervision_
 
LatticeInfo den_lat_scores_
 
Lattice den_lat_
 

Detailed Description

Definition at line 140 of file discriminative-supervision.h.

Member Typedef Documentation

◆ Lattice

typedef fst::VectorFst<LatticeArc> Lattice

Definition at line 143 of file discriminative-supervision.h.

◆ LatticeArc

typedef fst::ArcTpl<LatticeWeight> LatticeArc

Definition at line 142 of file discriminative-supervision.h.

Constructor & Destructor Documentation

◆ DiscriminativeSupervisionSplitter()

Definition at line 136 of file discriminative-supervision.cc.

References DiscriminativeSupervision::den_lat, DiscriminativeSupervisionSplitter::den_lat_, DiscriminativeSupervisionSplitter::den_lat_scores_, DiscriminativeSupervision::frames_per_sequence, KALDI_ASSERT, KALDI_WARN, DiscriminativeSupervision::num_sequences, DiscriminativeSupervisionSplitter::PrepareLattice(), DiscriminativeSupervisionSplitter::LatticeInfo::state_times, and DiscriminativeSupervisionSplitter::supervision_.

139  :
140  config_(config), tmodel_(tmodel), supervision_(supervision) {
141  if (supervision_.num_sequences != 1) {
142  KALDI_WARN << "Splitting already-reattached sequence (only expected in "
143  << "testing code)";
144  }
145 
146  KALDI_ASSERT(supervision_.num_sequences == 1); // For now, don't allow splitting already merged examples
147 
150 
151  int32 num_states = den_lat_.NumStates(),
153  KALDI_ASSERT(num_states > 0);
154  int32 start_state = den_lat_.Start();
155  // Lattice should be top-sorted and connected, so start-state must be 0.
156  KALDI_ASSERT(start_state == 0 && "Expecting start-state to be 0");
157 
158  KALDI_ASSERT(num_states == den_lat_scores_.state_times.size());
159  KALDI_ASSERT(den_lat_scores_.state_times[start_state] == 0);
160  KALDI_ASSERT(den_lat_scores_.state_times.back() == num_frames);
161 }
const SplitDiscriminativeSupervisionOptions & config_
kaldi::int32 int32
void PrepareLattice(Lattice *lat, LatticeInfo *scores) const
#define KALDI_WARN
Definition: kaldi-error.h:150
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

Member Function Documentation

◆ CollapseTransitionIds()

void CollapseTransitionIds ( const std::vector< int32 > &  state_times,
Lattice * lat
) const
private

Definition at line 169 of file discriminative-supervision.cc.

References count, KALDI_ASSERT, DiscriminativeSupervisionSplitter::tmodel_, and TransitionModel::TransitionIdToPdf().

Referenced by DiscriminativeSupervisionSplitter::CreateRangeLattice().

170  {
171  typedef Lattice::StateId StateId;
172  typedef Lattice::Arc Arc;
173 
174  int32 num_frames = state_times.back(); // TODO: Check if this is always true
175  StateId num_states = lat->NumStates();
176 
177  std::vector<std::map<int32, int32> > pdf_to_tid(num_frames);
178  for (StateId s = 0; s < num_states; s++) {
179  int32 t = state_times[s];
180  for (fst::MutableArcIterator<Lattice> aiter(lat, s);
181  !aiter.Done(); aiter.Next()) {
182  KALDI_ASSERT(t >= 0 && t < num_frames);
183  Arc arc = aiter.Value();
184  KALDI_ASSERT(arc.ilabel != 0 && arc.ilabel == arc.olabel);
185  int32 pdf = tmodel_.TransitionIdToPdf(arc.ilabel);
186  if (pdf_to_tid[t].count(pdf) != 0) {
187  arc.ilabel = arc.olabel = pdf_to_tid[t][pdf];
188  aiter.SetValue(arc);
189  } else {
190  pdf_to_tid[t][pdf] = arc.ilabel;
191  }
192  }
193  }
194 }
fst::StdArc::StateId StateId
Lattice::StateId StateId
kaldi::int32 int32
int32 TransitionIdToPdf(int32 trans_id) const
const size_t count
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ ComputeLatticeScores()

void ComputeLatticeScores ( const Lattice & lat,
LatticeInfo * scores
) const
private

Definition at line 393 of file discriminative-supervision.cc.

References DiscriminativeSupervisionSplitter::LatticeInfo::alpha, DiscriminativeSupervisionSplitter::LatticeInfo::beta, DiscriminativeSupervisionSplitter::LatticeInfo::Check(), kaldi::ComputeLatticeAlphasAndBetas(), kaldi::LatticeStateTimes(), and DiscriminativeSupervisionSplitter::LatticeInfo::state_times.

Referenced by DiscriminativeSupervisionSplitter::PrepareLattice().

394  {
395  LatticeStateTimes(lat, &(scores->state_times));
396  ComputeLatticeAlphasAndBetas(lat, false,
397  &(scores->alpha), &(scores->beta));
398  scores->Check();
399  // This check will fail if the lattice is not breadth-first search sorted
400 }
int32 LatticeStateTimes(const Lattice &lat, vector< int32 > *times)
This function iterates over the states of a topologically sorted lattice and counts the time instance...
double ComputeLatticeAlphasAndBetas(const LatticeType &lat, bool viterbi, vector< double > *alpha, vector< double > *beta)

◆ CreateRangeLattice()

void CreateRangeLattice ( const Lattice & in_lat,
const LatticeInfo & scores,
int32  begin_frame,
int32  end_frame,
bool  normalize,
Lattice * out_lat
) const
private

Definition at line 232 of file discriminative-supervision.cc.

References SplitDiscriminativeSupervisionOptions::acoustic_scale, fst::AcousticLatticeScale(), DiscriminativeSupervisionSplitter::LatticeInfo::alpha, DiscriminativeSupervisionSplitter::LatticeInfo::beta, SplitDiscriminativeSupervisionOptions::collapse_transition_ids, DiscriminativeSupervisionSplitter::CollapseTransitionIds(), DiscriminativeSupervisionSplitter::config_, SplitDiscriminativeSupervisionOptions::determinize, rnnlm::i, KALDI_ASSERT, KALDI_ERR, kaldi::LatticeStateTimes(), SplitDiscriminativeSupervisionOptions::minimize, LatticeWeightTpl< BaseFloat >::One(), fst::ScaleLattice(), LatticeWeightTpl< FloatType >::SetValue1(), LatticeWeightTpl< FloatType >::SetValue2(), DiscriminativeSupervisionSplitter::LatticeInfo::state_times, and kaldi::swap().

Referenced by DiscriminativeSupervisionSplitter::GetFrameRange().

235  {
236  typedef Lattice::StateId StateId;
237 
238  const std::vector<int32> &state_times = scores.state_times;
239 
240  // Some checks to ensure the lattice and scores are prepared properly
241  KALDI_ASSERT(state_times.size() == in_lat.NumStates());
242  if (!in_lat.Properties(fst::kTopSorted, true))
243  KALDI_ERR << "Input lattice must be topologically sorted.";
244 
245  std::vector<int32>::const_iterator begin_iter =
246  std::lower_bound(state_times.begin(), state_times.end(), begin_frame),
247  end_iter = std::lower_bound(begin_iter,
248  state_times.end(), end_frame);
249 
250  KALDI_ASSERT(*begin_iter == begin_frame &&
251  (begin_iter == state_times.begin() ||
252  begin_iter[-1] < begin_frame));
253  // even if end_frame == supervision_.num_frames, there should be a state with
254  // that frame index.
255  KALDI_ASSERT(end_iter[-1] < end_frame &&
256  (end_iter < state_times.end() || *end_iter == end_frame));
257  StateId begin_state = begin_iter - state_times.begin(),
258  end_state = end_iter - state_times.begin();
259 
260  KALDI_ASSERT(end_state > begin_state);
261  out_lat->DeleteStates();
262  out_lat->ReserveStates(end_state - begin_state + 2);
263 
264  // Add special start state
265  StateId start_state = out_lat->AddState();
266  out_lat->SetStart(start_state);
267 
268  for (StateId i = begin_state; i < end_state; i++)
269  out_lat->AddState();
270 
271  // Add the special final-state.
272  StateId final_state = out_lat->AddState();
273  out_lat->SetFinal(final_state, LatticeWeight::One());
274 
275  for (StateId state = begin_state; state < end_state; state++) {
276  StateId output_state = state - begin_state + 1;
277  if (state_times[state] == begin_frame) {
278  // we'd like to make this an initial state, but OpenFst doesn't allow
279  // multiple initial states. Instead we add an epsilon transition to it
280  // from our actual initial state. The weight on this
281  // transition is the forward probability of the said 'initial state'
283  weight.SetValue1((normalize ? scores.beta[0] : 0.0) - scores.alpha[state]);
284  // Add negative of the forward log-probability to the graph cost score,
285  // since the acoustic scores would be changed later.
286  // Assuming that the lattice is scaled with appropriate acoustic
287  // scale.
288  // We additionally normalize using the total lattice score. Since the
289  // same score is added as normalizer to all the paths in the lattice,
290  // the relative probabilities of the paths in the lattice is not affected.
291  // Note: Doing a forward-backward on this split must result in a total
292  // score of 0 because of the normalization.
293 
294  out_lat->AddArc(start_state,
295  LatticeArc(0, 0, weight, output_state));
296  } else {
297  KALDI_ASSERT(scores.state_times[state] < end_frame);
298  }
299  for (fst::ArcIterator<Lattice> aiter(in_lat, state);
300  !aiter.Done(); aiter.Next()) {
301  const LatticeArc &arc = aiter.Value();
302  StateId nextstate = arc.nextstate;
303  if (nextstate >= end_state) {
304  // A transition to any state outside the range becomes a transition to
305  // our special final-state.
306  // The weight is just the negative of the backward log-probability +
307  // the arc cost. We again normalize with the total lattice score.
308  LatticeWeight weight;
309  //KALDI_ASSERT(scores.beta[state] < 0);
310  weight.SetValue1(arc.weight.Value1() - scores.beta[nextstate]);
311  weight.SetValue2(arc.weight.Value2());
312  // Add negative of the backward log-probability to the LM score, since
313  // the acoustic scores would be changed later.
314  // Note: We don't normalize here because that is already done with the
315  // initial cost.
316 
317  out_lat->AddArc(output_state,
318  LatticeArc(arc.ilabel, arc.olabel, weight, final_state));
319  } else {
320  StateId output_nextstate = nextstate - begin_state + 1;
321  out_lat->AddArc(output_state,
322  LatticeArc(arc.ilabel, arc.olabel, arc.weight, output_nextstate));
323  }
324  }
325  }
326 
327  // Get rid of the word labels and put the
328  // transition-ids on both sides.
329  fst::Project(out_lat, fst::PROJECT_INPUT);
330  fst::RmEpsilon(out_lat);
331 
333  CollapseTransitionIds(state_times, out_lat);
334 
335  if (config_.determinize) {
336  if (!config_.minimize) {
337  Lattice tmp_lat;
338  fst::Determinize(*out_lat, &tmp_lat);
339  std::swap(*out_lat, tmp_lat);
340  } else {
341  Lattice tmp_lat;
342  fst::Reverse(*out_lat, &tmp_lat);
343  fst::Determinize(tmp_lat, out_lat);
344  fst::Reverse(*out_lat, &tmp_lat);
345  fst::Determinize(tmp_lat, out_lat);
346  fst::RmEpsilon(out_lat);
347  }
348  }
349 
350  fst::TopSort(out_lat);
351  std::vector<int32> state_times_tmp;
352  KALDI_ASSERT(LatticeStateTimes(*out_lat, &state_times_tmp) ==
353  end_frame - begin_frame);
354 
355  // Remove the acoustic scale that was previously added
356  if (config_.acoustic_scale != 1.0) {
358  1 / config_.acoustic_scale), out_lat);
359  }
360 }
fst::StdArc::StateId StateId
int32 LatticeStateTimes(const Lattice &lat, vector< int32 > *times)
This function iterates over the states of a topologically sorted lattice and counts the time instance...
static const LatticeWeightTpl One()
const SplitDiscriminativeSupervisionOptions & config_
Lattice::StateId StateId
void CollapseTransitionIds(const std::vector< int32 > &state_times, Lattice *lat) const
void swap(basic_filebuf< CharT, Traits > &x, basic_filebuf< CharT, Traits > &y)
fst::LatticeWeightTpl< BaseFloat > LatticeWeight
Definition: kaldi-lattice.h:32
std::vector< std::vector< double > > AcousticLatticeScale(double acwt)
void ScaleLattice(const std::vector< std::vector< ScaleFloat > > &scale, MutableFst< ArcTpl< Weight > > *fst)
Scales the pairs of weights in LatticeWeight or CompactLatticeWeight by viewing the pair (a...
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ DenLat()

const Lattice& DenLat ( ) const
inline

◆ GetFrameRange()

void GetFrameRange ( int32  begin_frame,
int32  frames_per_sequence,
bool  normalize,
DiscriminativeSupervision * supervision
) const

Definition at line 206 of file discriminative-supervision.cc.

References DiscriminativeSupervision::Check(), DiscriminativeSupervisionSplitter::CreateRangeLattice(), DiscriminativeSupervision::den_lat, DiscriminativeSupervisionSplitter::den_lat_, DiscriminativeSupervisionSplitter::den_lat_scores_, DiscriminativeSupervision::frames_per_sequence, KALDI_ASSERT, DiscriminativeSupervision::num_ali, DiscriminativeSupervision::num_sequences, DiscriminativeSupervisionSplitter::supervision_, and DiscriminativeSupervision::weight.

Referenced by kaldi::nnet3::ProcessFile().

207  {
208  int32 end_frame = begin_frame + num_frames;
209  // Note: end_frame is not included in the range of frames that the
210  // output supervision object covers; it's one past the end.
211  KALDI_ASSERT(num_frames > 0 && begin_frame >= 0 &&
212  begin_frame + num_frames <=
214 
217  begin_frame, end_frame, normalize,
218  &(out_supervision->den_lat));
219 
220  out_supervision->num_ali.clear();
221  std::copy(supervision_.num_ali.begin() + begin_frame,
222  supervision_.num_ali.begin() + end_frame,
223  std::back_inserter(out_supervision->num_ali));
224 
225  out_supervision->num_sequences = 1;
226  out_supervision->weight = supervision_.weight;
227  out_supervision->frames_per_sequence = num_frames;
228 
229  out_supervision->Check();
230 }
void CreateRangeLattice(const Lattice &in_lat, const LatticeInfo &scores, int32 begin_frame, int32 end_frame, bool normalize, Lattice *out_lat) const
kaldi::int32 int32
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ PrepareLattice()

void PrepareLattice ( Lattice * lat,
LatticeInfo * scores
) const
private

Definition at line 362 of file discriminative-supervision.cc.

References SplitDiscriminativeSupervisionOptions::acoustic_scale, fst::AcousticLatticeScale(), DiscriminativeSupervisionSplitter::ComputeLatticeScores(), DiscriminativeSupervisionSplitter::config_, KALDI_ASSERT, kaldi::LatticeStateTimes(), fst::ScaleLattice(), and DiscriminativeSupervisionSplitter::LatticeInfo::state_times.

Referenced by DiscriminativeSupervisionSplitter::DiscriminativeSupervisionSplitter().

363  {
364  // Scale the lattice to appropriate acoustic scale. It is important to
365  // ensure this is equal to the acoustic scale used while training. This is
366  // because, on splitting lattices, the initial and final costs are added
367  // into the graph cost.
369  if (config_.acoustic_scale != 1.0)
371  config_.acoustic_scale), lat);
372 
373  LatticeStateTimes(*lat, &(scores->state_times));
374  int32 num_states = lat->NumStates();
375  std::vector<std::pair<int32,int32> > state_time_indexes(num_states);
376  for (int32 s = 0; s < num_states; s++) {
377  state_time_indexes[s] = std::make_pair(scores->state_times[s], s);
378  }
379 
380  // Order the states based on the state times. This is stronger than just
381  // topological sort. This is required by the lattice splitting code.
382  std::sort(state_time_indexes.begin(), state_time_indexes.end());
383 
384  std::vector<int32> state_order(num_states);
385  for (int32 s = 0; s < num_states; s++) {
386  state_order[state_time_indexes[s].second] = s;
387  }
388 
389  fst::StateSort(lat, state_order);
390  ComputeLatticeScores(*lat, scores);
391 }
int32 LatticeStateTimes(const Lattice &lat, vector< int32 > *times)
This function iterates over the states of a topologically sorted lattice and counts the time instance...
void ComputeLatticeScores(const Lattice &lat, LatticeInfo *scores) const
const SplitDiscriminativeSupervisionOptions & config_
kaldi::int32 int32
std::vector< std::vector< double > > AcousticLatticeScale(double acwt)
void ScaleLattice(const std::vector< std::vector< ScaleFloat > > &scale, MutableFst< ArcTpl< Weight > > *fst)
Scales the pairs of weights in LatticeWeight or CompactLatticeWeight by viewing the pair (a...
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

Member Data Documentation

◆ config_

◆ den_lat_

◆ den_lat_scores_

◆ supervision_

◆ tmodel_

const TransitionModel& tmodel_
private

The documentation for this class was generated from the following files: