kaldi::discriminative Namespace Reference

Classes

class  DiscriminativeComputation
 
struct  DiscriminativeObjectiveInfo
 
struct  DiscriminativeOptions
 
struct  DiscriminativeSupervision
 
struct  DiscriminativeSupervisionOptions
 
class  DiscriminativeSupervisionSplitter
 
struct  SplitDiscriminativeSupervisionOptions
 

Typedefs

typedef TableWriter< KaldiObjectHolder< DiscriminativeSupervision > > DiscriminativeSupervisionWriter
 
typedef SequentialTableReader< KaldiObjectHolder< DiscriminativeSupervision > > SequentialDiscriminativeSupervisionReader
 
typedef RandomAccessTableReader< KaldiObjectHolder< DiscriminativeSupervision > > RandomAccessDiscriminativeSupervisionReader
 

Functions

void AppendSupervision (const std::vector< const DiscriminativeSupervision * > &input, bool compactify, std::vector< DiscriminativeSupervision > *output_supervision)
 This function appends a list of supervision objects, usually producing a single merged object. An input is merged into the previous output only if 'compactify' is true and both have the same weight and num-frames; otherwise it starts a new output object.
 
void ComputeDiscriminativeObjfAndDeriv (const DiscriminativeOptions &opts, const TransitionModel &tmodel, const CuVectorBase< BaseFloat > &log_priors, const DiscriminativeSupervision &supervision, const CuMatrixBase< BaseFloat > &nnet_output, DiscriminativeObjectiveInfo *stats, CuMatrixBase< BaseFloat > *nnet_output_deriv, CuMatrixBase< BaseFloat > *xent_output_deriv)
 This function does forward-backward on the numerator and denominator lattices and computes derivatives with respect to the neural-net output for the specified objective function.
 

Function Documentation

void AppendSupervision ( const std::vector< const DiscriminativeSupervision * > &  input,
bool  compactify,
std::vector< DiscriminativeSupervision > *  output_supervision 
)

This function appends a list of supervision objects, usually producing a single merged object. An input is merged into the previous output only if 'compactify' is true and both have the same weight and num-frames; otherwise it starts a new output object, so inputs whose weights or num-frames differ yield several output objects.

The normal use-case for this is when you are combining neural-net examples for training; appending them like this helps to simplify the training process.
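
The merge rule described above can be sketched without any Kaldi dependencies. The following is a simplified illustration, not the library code: ToySupervision and AppendToySupervision are hypothetical names, the struct keeps only the fields needed to show the rule, and the denominator lattice (with its concatenation and top-sort steps) is left out entirely.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for DiscriminativeSupervision. It keeps only the
// fields needed to illustrate the merge rule; there is no lattice here.
struct ToySupervision {
  float weight = 1.0f;
  int frames_per_sequence = 0;
  int num_sequences = 1;
  std::vector<int> num_ali;  // numerator alignment, one entry per frame
};

// Merges each input into the previous output when 'compactify' is true and
// the weight and frames_per_sequence match; otherwise starts a new output.
void AppendToySupervision(const std::vector<const ToySupervision*> &input,
                          bool compactify,
                          std::vector<ToySupervision> *output) {
  assert(!input.empty());
  output->clear();
  output->reserve(input.size());
  for (const ToySupervision *src : input) {
    assert(src->num_sequences == 1);  // inputs must be single sequences
    if (compactify && !output->empty() &&
        output->back().weight == src->weight &&
        output->back().frames_per_sequence == src->frames_per_sequence) {
      // Same weight and length as the previous output: append to it.
      ToySupervision &dst = output->back();
      dst.num_ali.insert(dst.num_ali.end(),
                         src->num_ali.begin(), src->num_ali.end());
      dst.num_sequences++;
    } else {
      // Different weight or length, or compactify is off: new output.
      output->push_back(*src);
    }
  }
}
```

Under these assumptions, three inputs where the first two share a weight and length and the third has a different weight produce two outputs when compactify is true (the first holding two sequences) and three outputs when it is false. Only adjacent inputs are compared, so the order of the input list affects how many outputs result.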

Definition at line 405 of file discriminative-supervision.cc.

References DiscriminativeSupervision::Check(), DiscriminativeSupervision::den_lat, DiscriminativeSupervision::frames_per_sequence, rnnlm::i, KALDI_ASSERT, DiscriminativeSupervision::num_ali, DiscriminativeSupervision::num_sequences, and DiscriminativeSupervision::weight.

Referenced by kaldi::nnet3::MergeSupervision().

{
  KALDI_ASSERT(!input.empty());
  int32 num_inputs = input.size();
  if (num_inputs == 1) {
    output_supervision->resize(1);
    (*output_supervision)[0] = *(input[0]);
    return;
  }
  std::vector<bool> output_was_merged;
  output_supervision->clear();
  output_supervision->reserve(input.size());
  for (int32 i = 0; i < input.size(); i++) {
    const DiscriminativeSupervision &src = *(input[i]);
    KALDI_ASSERT(src.num_sequences == 1);
    if (compactify && !output_supervision->empty() &&
        output_supervision->back().weight == src.weight &&
        output_supervision->back().frames_per_sequence ==
        src.frames_per_sequence) {
      // Combine with current output
      // append src.den_lat to output_supervision->den_lat.
      fst::Concat(&output_supervision->back().den_lat, src.den_lat);

      output_supervision->back().num_ali.insert(
          output_supervision->back().num_ali.end(),
          src.num_ali.begin(), src.num_ali.end());

      output_supervision->back().num_sequences++;
      output_was_merged.back() = true;
    } else {
      output_supervision->resize(output_supervision->size() + 1);
      output_supervision->back() = src;
      output_was_merged.push_back(false);
    }
  }
  KALDI_ASSERT(output_was_merged.size() == output_supervision->size());
  for (size_t i = 0; i < output_supervision->size(); i++) {
    if (output_was_merged[i]) {
      DiscriminativeSupervision &out_sup = (*output_supervision)[i];
      fst::TopSort(&(out_sup.den_lat));
      out_sup.Check();
    }
  }
}
void ComputeDiscriminativeObjfAndDeriv ( const DiscriminativeOptions &  opts,
const TransitionModel &  tmodel,
const CuVectorBase< BaseFloat > &  log_priors,
const DiscriminativeSupervision &  supervision,
const CuMatrixBase< BaseFloat > &  nnet_output,
DiscriminativeObjectiveInfo *  stats,
CuMatrixBase< BaseFloat > *  nnet_output_deriv,
CuMatrixBase< BaseFloat > *  xent_output_deriv 
)

This function does forward-backward on the numerator and denominator lattices and computes derivatives with respect to the neural-net output for the specified objective function.

Parameters
[in]   opts               Struct containing options.
[in]   tmodel             Transition model.
[in]   log_priors         Vector of log-priors for pdfs.
[in]   supervision        The supervision object, containing the numerator and denominator paths. The denominator is always a lattice. The numerator is an alignment.
[in]   nnet_output        The output of the neural net; its dimensions must be (supervision.num_sequences * supervision.frames_per_sequence) rows by tmodel.NumPdfs() columns.
[out]  stats              Statistics accumulated during training, such as the objective function and the total weight.
[out]  xent_output_deriv  If non-NULL, the xent objective derivative (which equals a posterior from the numerator forward-backward, scaled by the supervision weight) is written here. This is used in the cross-entropy regularization code.
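
Two conventions from the parameter list can be illustrated with a small stand-alone sketch: the required shape of nnet_output, and the optional xent_output_deriv pointer. This is an illustration under stated assumptions rather than Kaldi code: ToyMatrix is a hypothetical row-major stand-in for CuMatrixBase< BaseFloat >, and OutputDimsMatch and WriteXentDeriv are hypothetical helpers. The numerator posterior is passed in directly instead of being computed by forward-backward.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical dense row-major matrix standing in for CuMatrixBase<BaseFloat>.
struct ToyMatrix {
  int rows;
  int cols;
  std::vector<float> data;
  ToyMatrix(int r, int c)
      : rows(r), cols(c), data(static_cast<std::size_t>(r) * c, 0.0f) {}
};

// Returns true if nnet_output has (num_sequences * frames_per_sequence)
// rows and num_pdfs columns, as the parameter description requires.
bool OutputDimsMatch(int num_sequences, int frames_per_sequence,
                     int num_pdfs, const ToyMatrix &nnet_output) {
  return nnet_output.rows == num_sequences * frames_per_sequence &&
         nnet_output.cols == num_pdfs;
}

// Writes weight * numerator_post into *xent_output_deriv when the pointer
// is non-null and returns true; returns false and writes nothing otherwise.
bool WriteXentDeriv(const ToyMatrix &numerator_post, float weight,
                    ToyMatrix *xent_output_deriv) {
  if (xent_output_deriv == nullptr) return false;
  assert(xent_output_deriv->rows == numerator_post.rows &&
         xent_output_deriv->cols == numerator_post.cols);
  for (std::size_t i = 0; i < numerator_post.data.size(); i++)
    xent_output_deriv->data[i] = weight * numerator_post.data[i];
  return true;
}
```

For example, with 2 sequences of 3 frames each and 4 pdfs, OutputDimsMatch accepts only a 6 by 4 matrix. Passing a null pointer to WriteXentDeriv skips the cross-entropy output, mirroring the "if non-NULL" wording in the parameter description.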

Definition at line 546 of file discriminative-training.cc.

References DiscriminativeComputation::Compute().

Referenced by NnetDiscriminativeComputeObjf::ProcessOutputs(), and NnetDiscriminativeTrainer::ProcessOutputs().

{
  DiscriminativeComputation computation(opts, tmodel, log_priors, supervision,
                                        nnet_output, stats,
                                        nnet_output_deriv, xent_output_deriv);
  computation.Compute();
}