kaldi::discriminative Namespace Reference

Classes

class DiscriminativeComputation
struct DiscriminativeObjectiveInfo
struct DiscriminativeOptions
struct DiscriminativeSupervision
class DiscriminativeSupervisionSplitter
struct SplitDiscriminativeSupervisionOptions

Functions

void MergeSupervision (const std::vector< const DiscriminativeSupervision * > &input, DiscriminativeSupervision *output_supervision)
 This function appends a list of supervision objects to create a single such object; all inputs must share the same weight and frames_per_sequence.

void ComputeDiscriminativeObjfAndDeriv (const DiscriminativeOptions &opts, const TransitionModel &tmodel, const CuVectorBase< BaseFloat > &log_priors, const DiscriminativeSupervision &supervision, const CuMatrixBase< BaseFloat > &nnet_output, DiscriminativeObjectiveInfo *stats, CuMatrixBase< BaseFloat > *nnet_output_deriv, CuMatrixBase< BaseFloat > *xent_output_deriv)
 This function does forward-backward on the numerator and denominator lattices and computes derivatives with respect to the neural-net output for the specified objective function.

Function Documentation

◆ ComputeDiscriminativeObjfAndDeriv()

void ComputeDiscriminativeObjfAndDeriv ( const DiscriminativeOptions &  opts,
const TransitionModel &  tmodel,
const CuVectorBase< BaseFloat > &  log_priors,
const DiscriminativeSupervision &  supervision,
const CuMatrixBase< BaseFloat > &  nnet_output,
DiscriminativeObjectiveInfo *  stats,
CuMatrixBase< BaseFloat > *  nnet_output_deriv,
CuMatrixBase< BaseFloat > *  xent_output_deriv
)

This function does forward-backward on the numerator and denominator lattices and computes derivatives with respect to the neural-net output for the specified objective function.
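To make the forward-backward step concrete, here is a minimal, self-contained sketch. It computes per-frame pdf posteriors on a toy denominator lattice, then combines them with a numerator alignment into an MMI-style gradient (numerator occupancy minus denominator occupancy, times the supervision weight). Everything in it is hypothetical: `ToyArc`, `FramePdfPosteriors` and `MmiStyleDeriv` are not Kaldi APIs. The sketch makes three simplifying choices:

- It assumes an acyclic lattice with topologically numbered states and a single final state.
- It ignores acoustic scaling and log-priors.
- It does not reproduce the other objective functions this function supports.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

const double kLogZero = -std::numeric_limits<double>::infinity();

// Hypothetical toy lattice arc: state `from` -> state `to` on frame `frame`,
// labelled with pdf-id `pdf` and carrying a combined log-score.
struct ToyArc {
  int from, to, frame, pdf;
  double log_score;
};

// Numerically stable log(exp(a) + exp(b)).
double LogAdd(double a, double b) {
  if (a == kLogZero) return b;
  if (b == kLogZero) return a;
  const double m = std::max(a, b);
  return m + std::log(std::exp(a - m) + std::exp(b - m));
}

// Forward-backward over an acyclic lattice whose states are numbered in
// topological order (every arc has from < to), with start state 0 and a
// single final state num_states - 1. Returns post[t][pdf], the summed
// posterior of all arcs on frame t that carry that pdf.
std::vector<std::vector<double>> FramePdfPosteriors(
    const std::vector<ToyArc>& arcs, int num_states, int num_frames,
    int num_pdfs) {
  assert(num_states >= 2);
  std::vector<ToyArc> sorted = arcs;
  std::sort(sorted.begin(), sorted.end(),
            [](const ToyArc& x, const ToyArc& y) { return x.from < y.from; });
  // Forward pass: alpha[s] = log total score of all paths from start to s.
  std::vector<double> alpha(num_states, kLogZero);
  alpha[0] = 0.0;
  for (const ToyArc& arc : sorted)
    alpha[arc.to] = LogAdd(alpha[arc.to], alpha[arc.from] + arc.log_score);
  // Backward pass: beta[s] = log total score of all paths from s to the end.
  std::vector<double> beta(num_states, kLogZero);
  beta[num_states - 1] = 0.0;
  for (auto it = sorted.rbegin(); it != sorted.rend(); ++it)
    beta[it->from] = LogAdd(beta[it->from], it->log_score + beta[it->to]);
  const double log_total = alpha[num_states - 1];
  assert(log_total > kLogZero);  // the final state must be reachable
  std::vector<std::vector<double>> post(num_frames,
                                        std::vector<double>(num_pdfs, 0.0));
  for (const ToyArc& arc : arcs)
    post[arc.frame][arc.pdf] +=
        std::exp(alpha[arc.from] + arc.log_score + beta[arc.to] - log_total);
  return post;
}

// MMI-style gradient with respect to the per-frame pdf log-scores: one-hot
// numerator occupancy (from the alignment) minus denominator occupancy,
// scaled by the supervision weight.
std::vector<std::vector<double>> MmiStyleDeriv(
    const std::vector<int>& num_ali,
    const std::vector<std::vector<double>>& den_post, double weight) {
  std::vector<std::vector<double>> deriv = den_post;
  for (size_t t = 0; t < deriv.size(); ++t) {
    for (double& v : deriv[t]) v = -weight * v;
    deriv[t][num_ali[t]] += weight;
  }
  return deriv;
}
```

Consider a lattice with two parallel frame-0 arcs whose probabilities are in a 3:1 ratio, followed by a single frame-1 arc. For this lattice, the frame-0 posteriors come out as 0.75 and 0.25. Every row of the MMI-style gradient sums to zero.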

Parameters
[in] opts  Struct containing options.
[in] tmodel  Transition model.
[in] log_priors  Vector of log-priors for pdfs.
[in] supervision  The supervision object, containing the numerator and denominator paths. The denominator is always a lattice; the numerator is an alignment.
[in] nnet_output  The output of the neural net; it must have (supervision.num_sequences * supervision.frames_per_sequence) rows and tmodel.NumPdfs() columns.
[out] stats  Statistics accumulated during training, such as the objective function and the total weight.
[out] nnet_output_deriv  The derivative of the objective function with respect to nnet_output.
[out] xent_output_deriv  If non-NULL, the xent objective derivative (which equals a posterior from the numerator forward-backward, scaled by the supervision weight) is written here. It is used by the cross-entropy regularization code.
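The dimension requirement on nnet_output and the description of xent_output_deriv can be illustrated with a short sketch. `ToySupervisionShape`, `ExpectedOutputRows` and `XentDerivFromAlignment` are hypothetical names, not Kaldi APIs. The sketch assumes, as the parameter list states, that the numerator is a single alignment. With only one numerator path, its forward-backward posterior is one-hot on each frame. The sketch indexes rows in alignment order and makes no claim about how Kaldi orders rows across sequences.

```cpp
#include <cassert>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for the two DiscriminativeSupervision fields used here.
struct ToySupervisionShape {
  int num_sequences;
  int frames_per_sequence;
};

// Number of rows nnet_output must have: one row per frame of every sequence.
int ExpectedOutputRows(const ToySupervisionShape& s) {
  return s.num_sequences * s.frames_per_sequence;
}

// Cross-entropy derivative for a numerator given as an alignment: the
// numerator posterior is 1 for the aligned pdf on each frame and 0 elsewhere,
// so the derivative is that one-hot matrix scaled by the supervision weight.
std::vector<std::vector<double>> XentDerivFromAlignment(
    const std::vector<int>& num_ali, int num_pdfs, double weight) {
  std::vector<std::vector<double>> deriv(num_ali.size(),
                                         std::vector<double>(num_pdfs, 0.0));
  for (size_t t = 0; t < num_ali.size(); ++t) {
    if (num_ali[t] < 0 || num_ali[t] >= num_pdfs)
      throw std::out_of_range("pdf-id out of range");
    deriv[t][num_ali[t]] = weight;
  }
  return deriv;
}
```

For example, 4 sequences of 150 frames require an nnet_output with 600 rows and tmodel.NumPdfs() columns.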

Definition at line 546 of file discriminative-training.cc.

References DiscriminativeComputation::Compute().

Referenced by DiscriminativeObjectiveInfo::AccumulateOutput(), NnetDiscriminativeComputeObjf::ProcessOutputs(), and NnetDiscriminativeTrainer::ProcessOutputs().

```cpp
{
  DiscriminativeComputation computation(opts, tmodel, log_priors, supervision,
                                        nnet_output, stats,
                                        nnet_output_deriv, xent_output_deriv);
  computation.Compute();
}
```

◆ MergeSupervision()

void MergeSupervision ( const std::vector< const DiscriminativeSupervision * > &  input,
DiscriminativeSupervision *  output_supervision
)

This function appends a list of supervision objects to create a single such object. Every input must have the same weight and frames_per_sequence (a mismatch is reported with KALDI_ERR), and every input except the last must contain exactly one sequence.

The normal use-case for this is when you are combining neural-net examples for training; appending them like this helps to simplify the training process.

Definition at line 402 of file discriminative-supervision.cc.

References DiscriminativeSupervision::Check(), DiscriminativeSupervision::den_lat, DiscriminativeSupervision::frames_per_sequence, KALDI_ASSERT, KALDI_ERR, DiscriminativeSupervision::num_ali, DiscriminativeSupervision::num_sequences, and DiscriminativeSupervision::weight.

Referenced by kaldi::nnet3::MergeSupervision().

```cpp
{
  KALDI_ASSERT(!input.empty());
  int32 num_inputs = input.size();
  if (num_inputs == 1) {
    *output_supervision = *(input[0]);
    return;
  }
  *output_supervision = *(input[num_inputs-1]);
  for (int32 i = num_inputs - 2; i >= 0; i--) {
    const DiscriminativeSupervision &src = *(input[i]);
    KALDI_ASSERT(src.num_sequences == 1);
    if (output_supervision->weight == src.weight &&
        output_supervision->frames_per_sequence ==
        src.frames_per_sequence) {
      // Combine with current output.
      // Inputs are visited from last to first, so src.den_lat is prepended:
      // fst::Concat(fst1, &fst2) sets fst2 to fst1 followed by fst2.
      fst::Concat(src.den_lat, &output_supervision->den_lat);

      output_supervision->num_ali.insert(
          output_supervision->num_ali.begin(),
          src.num_ali.begin(), src.num_ali.end());

      output_supervision->num_sequences++;
    } else {
      KALDI_ERR << "Mismatch weight or frames_per_sequence between inputs";
    }
  }
  DiscriminativeSupervision &out_sup = *output_supervision;
  fst::TopSort(&(out_sup.den_lat));
  out_sup.Check();
}
```
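The bookkeeping in the listing above can be reproduced outside Kaldi with a simplified stand-in. In this sketch, `ToySupervision` and `ToyMergeSupervision` are hypothetical, not Kaldi APIs. The sketch differs from the real function in three ways:

- The denominator lattice is replaced by a vector of labels, so the prepend order is visible.
- KALDI_ERR is replaced by throwing std::runtime_error.
- fst::TopSort() and Check() have no counterpart here.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical simplified supervision: den stands in for den_lat.
struct ToySupervision {
  double weight = 1.0;
  int num_sequences = 1;
  int frames_per_sequence = 0;
  std::vector<int> num_ali;       // numerator alignment, one pdf-id per frame
  std::vector<std::string> den;   // stand-in for the denominator lattice
};

// Follows the control flow of MergeSupervision(): start from the last input
// and prepend earlier ones, so the output keeps the original input order.
void ToyMergeSupervision(const std::vector<const ToySupervision*>& input,
                         ToySupervision* out) {
  if (input.empty()) throw std::invalid_argument("empty input");
  *out = *input.back();
  for (int i = static_cast<int>(input.size()) - 2; i >= 0; --i) {
    const ToySupervision& src = *input[i];
    assert(src.num_sequences == 1);
    if (out->weight != src.weight ||
        out->frames_per_sequence != src.frames_per_sequence)
      throw std::runtime_error(
          "Mismatch weight or frames_per_sequence between inputs");
    // Prepend src's frames, mirroring fst::Concat(src.den_lat, &out_lat).
    out->den.insert(out->den.begin(), src.den.begin(), src.den.end());
    out->num_ali.insert(out->num_ali.begin(), src.num_ali.begin(),
                        src.num_ali.end());
    out->num_sequences++;
  }
}
```

Suppose three single-sequence inputs have the alignments {1, 1}, {2, 2} and {3, 3}. Merging them yields num_sequences == 3 and the alignment {1, 1, 2, 2, 3, 3}. The original input order is preserved even though the loop runs backwards.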