Public Member Functions | |
NnetDiscriminativeUpdater (const AmNnet &am_nnet, const TransitionModel &tmodel, const NnetDiscriminativeUpdateOptions &opts, const DiscriminativeNnetExample &eg, Nnet *nnet_to_update, NnetDiscriminativeStats *stats) | |
void | Update () |
void | Propagate () |
The forward-through-the-layers part of the computation. | |
void | LatticeComputations () |
Performs the steps between Propagate() and Backprop() that involve forward-backward over the lattice. | |
void | Backprop () |
double | GetDiscriminativePosteriors (Posterior *post) |
Assuming the lattice already has the correct scores in it, this function does the MPE or MMI forward-backward and puts the resulting discriminative posteriors (which may have positive or negative weight) into "post". | |
SubMatrix< BaseFloat > | GetInputFeatures () const |
CuMatrixBase< BaseFloat > & | GetOutput () |
Static Public Member Functions | |
static Int32Pair | MakePair (int32 first, int32 second) |
Private Types | |
typedef LatticeArc | Arc |
typedef Arc::StateId | StateId |
Private Attributes | |
const AmNnet & | am_nnet_ |
const TransitionModel & | tmodel_ |
const NnetDiscriminativeUpdateOptions & | opts_ |
const DiscriminativeNnetExample & | eg_ |
Nnet * | nnet_to_update_ |
NnetDiscriminativeStats * | stats_ |
std::vector< ChunkInfo > | chunk_info_out_ |
std::vector< CuMatrix< BaseFloat > > | forward_data_ |
Lattice | lat_ |
CuMatrix< BaseFloat > | backward_data_ |
std::vector< int32 > | silence_phones_ |
Definition at line 32 of file nnet-compute-discriminative.cc.
Member Typedef Documentation | |
typedef LatticeArc Arc |
private |
Definition at line 78 of file nnet-compute-discriminative.cc.
typedef Arc::StateId StateId |
private |
Definition at line 79 of file nnet-compute-discriminative.cc.
Constructor & Destructor Documentation | |
NnetDiscriminativeUpdater | ( | const AmNnet & | am_nnet, |
const TransitionModel & | tmodel, | ||
const NnetDiscriminativeUpdateOptions & | opts, | ||
const DiscriminativeNnetExample & | eg, | ||
Nnet * | nnet_to_update, | ||
NnetDiscriminativeStats * | stats | ||
) |
Definition at line 101 of file nnet-compute-discriminative.cc.
References NnetDiscriminativeUpdater::am_nnet_, NnetDiscriminativeUpdater::chunk_info_out_, Nnet::ComputeChunkInfo(), NnetDiscriminativeUpdater::eg_, AmNnet::GetNnet(), DiscriminativeNnetExample::input_frames, KALDI_ERR, MatrixBase< Real >::NumRows(), NnetDiscriminativeUpdater::opts_, NnetDiscriminativeUpdater::silence_phones_, NnetDiscriminativeUpdateOptions::silence_phones_str, and kaldi::SplitStringToIntegers().
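The constructor's references (NnetDiscriminativeUpdateOptions::silence_phones_str, kaldi::SplitStringToIntegers(), KALDI_ERR) suggest that it parses the silence-phone option string into silence_phones_ and reports an error when parsing fails. The standalone sketch below illustrates only that parsing step. It assumes a colon-separated format such as "1:2:3". ParseSilencePhones is a hypothetical name, not a Kaldi function.

```cpp
#include <cstdlib>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper (not Kaldi code). Parses a colon-separated list of
// integer phone ids such as "1:2:3" into *phones. The ':' delimiter is an
// assumption. Returns false on malformed input (empty field, trailing colon,
// or non-integer text); the caller would then report an error.
bool ParseSilencePhones(const std::string &str, std::vector<int> *phones) {
  phones->clear();
  if (str.empty()) return true;         // no silence phones configured
  if (str.back() == ':') return false;  // trailing delimiter means an empty field
  std::stringstream ss(str);
  std::string field;
  while (std::getline(ss, field, ':')) {
    if (field.empty()) return false;    // e.g. "1::2" or ":1"
    char *end = nullptr;
    long value = std::strtol(field.c_str(), &end, 10);
    if (end == field.c_str() || *end != '\0') return false;  // not an integer
    phones->push_back(static_cast<int>(value));
  }
  return true;
}
```

An empty string is treated as "no silence phones" rather than as an error, on the assumption that the option is allowed to be left unset.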
Member Function Documentation | |
void Backprop | ( | ) |
Definition at line 349 of file nnet-compute-discriminative.cc.
References NnetDiscriminativeUpdater::am_nnet_, Component::Backprop(), NnetDiscriminativeUpdater::backward_data_, NnetDiscriminativeUpdater::chunk_info_out_, NnetDiscriminativeUpdater::forward_data_, Nnet::GetComponent(), AmNnet::GetNnet(), NnetDiscriminativeUpdater::nnet_to_update_, and Nnet::NumComponents().
Referenced by NnetDiscriminativeUpdater::Update().
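Backprop() has no description on this page. Its references (Component::Backprop(), forward_data_, backward_data_, Nnet::NumComponents(), nnet_to_update_) point to the usual reverse pass. The components are visited from last to first. Each one receives the derivative of the objective with respect to its output and returns the derivative with respect to its input, accumulating any parameter gradient along the way.

The sketch below shows that loop with a hypothetical one-parameter ScaleComponent. It illustrates the technique and is not Kaldi's Component interface.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy component (not a Kaldi class): y = w * x elementwise.
struct ScaleComponent {
  explicit ScaleComponent(double w_in) : w(w_in), grad_w(0.0) {}

  std::vector<double> Propagate(const std::vector<double> &in) const {
    std::vector<double> out(in.size());
    for (std::size_t i = 0; i < in.size(); ++i) out[i] = w * in[i];
    return out;
  }

  // Given the component's input and d(objf)/d(output), accumulates
  // d(objf)/dw and returns d(objf)/d(input).
  std::vector<double> Backprop(const std::vector<double> &in_value,
                               const std::vector<double> &out_deriv) {
    assert(in_value.size() == out_deriv.size());
    std::vector<double> in_deriv(in_value.size());
    for (std::size_t i = 0; i < in_value.size(); ++i) {
      grad_w += in_value[i] * out_deriv[i];
      in_deriv[i] = w * out_deriv[i];
    }
    return in_deriv;
  }

  double w;
  double grad_w;  // accumulated d(objf)/dw
};

// forward_data[c] is the input of component c; forward_data[c + 1] is its
// output. backward_data starts as d(objf)/d(network output) and is pushed
// back through the components from last to first.
void BackpropAll(std::vector<ScaleComponent> *components,
                 const std::vector<std::vector<double>> &forward_data,
                 std::vector<double> backward_data) {
  assert(forward_data.size() == components->size() + 1);
  for (std::size_t c = components->size(); c-- > 0;) {
    backward_data = (*components)[c].Backprop(forward_data[c], backward_data);
  }
}
```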
double GetDiscriminativePosteriors | ( | Posterior * | post | ) |
Assuming the lattice already has the correct scores in it, this function does the MPE or MMI forward-backward and puts the resulting discriminative posteriors (which may have positive or negative weight) into "post".
For MPFE or SMBR it returns the objective function; for MMI it returns the negative of the denominator-lattice log-likelihood.
Definition at line 326 of file nnet-compute-discriminative.cc.
References kaldi::ConvertPosteriorToPdfs(), NnetDiscriminativeUpdateOptions::criterion, NnetDiscriminativeUpdateOptions::drop_frames, NnetDiscriminativeUpdater::eg_, KALDI_ASSERT, NnetDiscriminativeUpdater::lat_, kaldi::LatticeForwardBackwardMmi(), kaldi::LatticeForwardBackwardMpeVariants(), DiscriminativeNnetExample::num_ali, NnetDiscriminativeUpdateOptions::one_silence_class, NnetDiscriminativeUpdater::opts_, NnetDiscriminativeUpdater::silence_phones_, and NnetDiscriminativeUpdater::tmodel_.
Referenced by NnetDiscriminativeUpdater::LatticeComputations(), and NnetDiscriminativeUpdater::Update().
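For MMI, the discriminative posterior on each frame is the numerator (reference) occupancy minus the denominator-lattice occupancy. This is why entries can carry positive or negative weight.

The sketch below computes that difference frame by frame. It assumes the reference is a single alignment already mapped to pdf ids, and that the denominator posteriors are supplied per frame. The real function works on the lattice through kaldi::LatticeForwardBackwardMmi() and also honors options such as drop_frames, which this sketch ignores. MmiPosteriors is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <utility>
#include <vector>

// Same shape as Kaldi's Posterior: per frame, a list of (pdf-id, weight).
using Posterior = std::vector<std::vector<std::pair<int, float>>>;

// Hypothetical MMI sketch: numerator occupancy (1.0 for the reference pdf on
// each frame) minus denominator occupancy. Entries for the same pdf are
// merged, and exact zeros are dropped.
Posterior MmiPosteriors(const std::vector<int> &num_pdfs,
                        const Posterior &den_post) {
  assert(num_pdfs.size() == den_post.size());
  Posterior post(num_pdfs.size());
  for (std::size_t t = 0; t < num_pdfs.size(); ++t) {
    std::map<int, float> acc;  // pdf-id -> net weight on this frame
    acc[num_pdfs[t]] += 1.0f;
    for (const auto &p : den_post[t]) acc[p.first] -= p.second;
    for (const auto &kv : acc)
      if (kv.second != 0.0f) post[t].emplace_back(kv.first, kv.second);
  }
  return post;
}
```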
SubMatrix< BaseFloat > GetInputFeatures | ( | ) | const |
Definition at line 121 of file nnet-compute-discriminative.cc.
References NnetDiscriminativeUpdater::am_nnet_, NnetDiscriminativeUpdater::eg_, AmNnet::GetNnet(), DiscriminativeNnetExample::input_frames, KALDI_ASSERT, DiscriminativeNnetExample::left_context, Nnet::LeftContext(), DiscriminativeNnetExample::num_ali, MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), and Nnet::RightContext().
Referenced by NnetDiscriminativeUpdater::Propagate(), and NnetDiscriminativeUpdater::Update().
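GetInputFeatures() references:

- the example's input_frames and left_context;
- the number of aligned frames (num_ali);
- the network's LeftContext() and RightContext().

A plausible reading is that it selects the block of rows the network actually needs. It skips any surplus left context stored in the example and keeps left context, then the output frames, then right context.

The sketch below shows that arithmetic. The struct and function names are hypothetical, and the computation is a reconstruction from the references above, not a copy of the Kaldi source.

```cpp
#include <cassert>

// Hypothetical description of the row window taken from input_frames.
struct RowWindow {
  int offset;    // first row to use
  int num_rows;  // number of rows to use
};

// eg_left_context:   frames of left context stored in the example
// eg_num_rows:       total rows in the example's input_frames
// num_output_frames: frames being trained on (the size of num_ali)
// nnet_left_context / nnet_right_context: context the network requires
RowWindow InputWindow(int eg_left_context, int eg_num_rows,
                      int num_output_frames, int nnet_left_context,
                      int nnet_right_context) {
  int eg_right_context = eg_num_rows - num_output_frames - eg_left_context;
  assert(eg_right_context >= 0);
  // The example must carry at least as much context as the network needs.
  assert(eg_left_context >= nnet_left_context &&
         eg_right_context >= nnet_right_context);
  RowWindow w;
  w.offset = eg_left_context - nnet_left_context;  // skip surplus left context
  w.num_rows = nnet_left_context + num_output_frames + nnet_right_context;
  return w;
}
```

For example, an example with 10 frames of context on each side and 100 output frames, fed to a network with left context 4 and right context 6, would start at row 6 and use 110 rows.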
CuMatrixBase< BaseFloat > & GetOutput | ( | ) |
inline |
Definition at line 68 of file nnet-compute-discriminative.cc.
References NnetDiscriminativeUpdater::forward_data_.
void LatticeComputations | ( | ) |
Performs the steps between Propagate() and Backprop() that involve forward-backward over the lattice.
Definition at line 178 of file nnet-compute-discriminative.cc.
References NnetDiscriminativeUpdateOptions::acoustic_scale, NnetDiscriminativeUpdater::am_nnet_, NnetDiscriminativeUpdater::backward_data_, NnetDiscriminativeUpdateOptions::boost, fst::ConvertLattice(), NnetDiscriminativeUpdateOptions::criterion, DiscriminativeNnetExample::den_lat, VectorBase< Real >::Dim(), NnetDiscriminativeUpdater::eg_, NnetDiscriminativeUpdater::forward_data_, NnetDiscriminativeUpdater::GetDiscriminativePosteriors(), AmNnet::GetNnet(), KALDI_ASSERT, KALDI_ISINF, KALDI_ISNAN, KALDI_WARN, NnetDiscriminativeUpdater::lat_, kaldi::LatticeBoost(), kaldi::LatticeStateTimes(), kaldi::Log(), CuMatrixBase< Real >::Lookup(), NnetDiscriminativeUpdater::MakePair(), DiscriminativeNnetExample::num_ali, CuMatrixBase< Real >::NumCols(), Nnet::NumComponents(), CuMatrixBase< Real >::NumRows(), NnetDiscriminativeUpdater::opts_, AmNnet::Priors(), kaldi::ScalePosterior(), NnetDiscriminativeUpdater::silence_phones_, NnetDiscriminativeUpdater::stats_, NnetDiscriminativeUpdater::tmodel_, NnetDiscriminativeStats::tot_den_objf, NnetDiscriminativeStats::tot_num_count, NnetDiscriminativeStats::tot_num_objf, NnetDiscriminativeStats::tot_t, NnetDiscriminativeStats::tot_t_weighted, TransitionModel::TransitionIdToPdf(), DiscriminativeNnetExample::weight, and LatticeWeightTpl< BaseFloat >::Zero().
Referenced by NnetDiscriminativeUpdater::Update().
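LatticeComputations() references AmNnet::Priors(), kaldi::Log(), CuMatrixBase::Lookup(), the acoustic scale and the example weight. That fits the standard hybrid-training recipe:

- score lattice arcs with log(network output) minus log(prior);
- run the discriminative forward-backward;
- turn the resulting posteriors into a derivative with respect to the network output.

The sketch below shows only the two arithmetic pieces, with plain std::vector matrices standing in for Kaldi's types. How Kaldi combines acoustic_scale, boost and weight into the final factor is not shown on this page, so `scale` here is a single assumed factor.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

using Matrix = std::vector<std::vector<double>>;  // [frame][pdf]
// Per frame: list of (pdf-id, weight); weights may be negative.
using Posterior = std::vector<std::vector<std::pair<int, float>>>;

// Hybrid "pseudo log-likelihood" for pdf p at frame t: the network's
// posterior divided by the pdf prior, in the log domain.
double PseudoLogLike(const Matrix &output, const std::vector<double> &priors,
                     int t, int p) {
  return std::log(output[t][p]) - std::log(priors[p]);
}

// Chain rule through the log: if post(t, p) is d objf / d log output(t, p)
// up to the factor `scale`, then
// d objf / d output(t, p) = scale * post(t, p) / output(t, p).
// The log-prior term does not depend on the output, so it drops out.
Matrix OutputDerivative(const Matrix &output, const Posterior &post,
                        double scale) {
  assert(output.size() == post.size());
  Matrix deriv;
  for (const auto &row : output) deriv.emplace_back(row.size(), 0.0);
  for (std::size_t t = 0; t < post.size(); ++t)
    for (const auto &entry : post[t])
      deriv[t][entry.first] += scale * entry.second / output[t][entry.first];
  return deriv;
}
```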
static Int32Pair MakePair | ( | int32 | first, |
int32 | second | ||
) |
Definition at line 70 of file nnet-compute-discriminative.cc.
References Int32Pair::first, and Int32Pair::second.
Referenced by NnetDiscriminativeUpdater::LatticeComputations().
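MakePair() is used only by LatticeComputations(), which also references CuMatrixBase::Lookup(). Together they suggest building a list of (row, column) index pairs, such as (frame, pdf-id), and fetching the corresponding matrix entries in one batch.

The sketch below uses a local stand-in for Int32Pair, a plain struct with first and second fields, as the references above indicate. It pairs that with a plain-vector Lookup. This is an illustration, not Kaldi's CUDA implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Local stand-in shaped like Kaldi's Int32Pair (fields first and second).
struct Int32Pair {
  std::int32_t first;
  std::int32_t second;
};

// A plain struct has no two-argument constructor, so a small helper builds one.
inline Int32Pair MakePair(std::int32_t first, std::int32_t second) {
  Int32Pair ans;
  ans.first = first;
  ans.second = second;
  return ans;
}

// Hypothetical batched lookup: returns mat[row][col] for each (row, col)
// pair, e.g. (frame, pdf-id) pairs gathered from lattice arcs.
std::vector<float> Lookup(const std::vector<std::vector<float>> &mat,
                          const std::vector<Int32Pair> &indexes) {
  std::vector<float> answers;
  answers.reserve(indexes.size());
  for (const Int32Pair &idx : indexes) {
    assert(idx.first >= 0 &&
           static_cast<std::size_t>(idx.first) < mat.size());
    assert(idx.second >= 0 &&
           static_cast<std::size_t>(idx.second) < mat[idx.first].size());
    answers.push_back(mat[idx.first][idx.second]);
  }
  return answers;
}
```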
void Propagate | ( | ) |
The forward-through-the-layers part of the computation.
Definition at line 142 of file nnet-compute-discriminative.cc.
References NnetDiscriminativeUpdater::am_nnet_, Component::BackpropNeedsInput(), Component::BackpropNeedsOutput(), NnetDiscriminativeUpdater::chunk_info_out_, NnetDiscriminativeUpdater::eg_, NnetDiscriminativeUpdater::forward_data_, Nnet::GetComponent(), NnetDiscriminativeUpdater::GetInputFeatures(), AmNnet::GetNnet(), NnetDiscriminativeUpdater::nnet_to_update_, MatrixBase< Real >::NumCols(), Nnet::NumComponents(), MatrixBase< Real >::NumRows(), Component::Propagate(), and DiscriminativeNnetExample::spk_info.
Referenced by NnetDiscriminativeUpdater::Update().
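Propagate() references Component::BackpropNeedsInput(), Component::BackpropNeedsOutput() and nnet_to_update_. This suggests it frees intermediate activations that the backward pass will not need.

Under that assumption:

- buffer forward_data_[c] is the output of component c - 1 and the input of component c;
- it needs to survive only if a backward pass will run and either component c - 1 needs its output or component c needs its input;
- the final buffer (the network output) is always kept for the lattice computations.

The sketch below computes that keep-or-free decision with a hypothetical ComponentInfo struct. It is an inferred rule, not Kaldi's verbatim logic.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical per-component flags, mirroring Component::BackpropNeedsInput()
// and Component::BackpropNeedsOutput().
struct ComponentInfo {
  bool backprop_needs_input;
  bool backprop_needs_output;
};

// Returns, for each buffer c = 0 .. num_components, whether forward_data[c]
// must be kept after the forward pass.
std::vector<bool> BuffersToKeep(const std::vector<ComponentInfo> &components,
                                bool will_do_backprop) {
  const std::size_t n = components.size();
  std::vector<bool> keep(n + 1, false);
  keep[n] = true;  // network output: read by the lattice computations
  if (!will_do_backprop) return keep;
  for (std::size_t c = 0; c < n; ++c) {
    bool prev_needs_output = c > 0 && components[c - 1].backprop_needs_output;
    keep[c] = prev_needs_output || components[c].backprop_needs_input;
  }
  return keep;
}
```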
void Update | ( | ) |
inline |
Definition at line 42 of file nnet-compute-discriminative.cc.
References NnetDiscriminativeUpdater::Backprop(), NnetDiscriminativeUpdater::GetDiscriminativePosteriors(), NnetDiscriminativeUpdater::GetInputFeatures(), NnetDiscriminativeUpdater::LatticeComputations(), NnetDiscriminativeUpdater::nnet_to_update_, and NnetDiscriminativeUpdater::Propagate().
Referenced by kaldi::nnet2::NnetDiscriminativeUpdate().
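Update() ties the stages together in the presumed order: Propagate(), then LatticeComputations(), then Backprop(). That Backprop() is skipped when there is no network to update is an assumption, based on Update() referencing nnet_to_update_. The sketch below shows this with a hypothetical ToyUpdater whose stages only record that they ran.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for the updater; each stage just logs its name.
class ToyUpdater {
 public:
  explicit ToyUpdater(bool has_nnet_to_update)
      : has_nnet_to_update_(has_nnet_to_update) {}

  // Forward pass, then lattice forward-backward, then backprop only when
  // there is a network to update (assumed behavior, see lead-in).
  void Update() {
    Propagate();
    LatticeComputations();
    if (has_nnet_to_update_) Backprop();
  }

  const std::vector<std::string> &calls() const { return calls_; }

 private:
  void Propagate() { calls_.push_back("Propagate"); }
  void LatticeComputations() { calls_.push_back("LatticeComputations"); }
  void Backprop() { calls_.push_back("Backprop"); }

  bool has_nnet_to_update_;
  std::vector<std::string> calls_;
};
```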
Member Data Documentation | |
const AmNnet & am_nnet_ |
private |
Definition at line 82 of file nnet-compute-discriminative.cc.
Referenced by NnetDiscriminativeUpdater::Backprop(), NnetDiscriminativeUpdater::GetInputFeatures(), NnetDiscriminativeUpdater::LatticeComputations(), NnetDiscriminativeUpdater::NnetDiscriminativeUpdater(), and NnetDiscriminativeUpdater::Propagate().
CuMatrix< BaseFloat > backward_data_ |
private |
Definition at line 95 of file nnet-compute-discriminative.cc.
Referenced by NnetDiscriminativeUpdater::Backprop(), and NnetDiscriminativeUpdater::LatticeComputations().
std::vector< ChunkInfo > chunk_info_out_ |
private |
Definition at line 90 of file nnet-compute-discriminative.cc.
Referenced by NnetDiscriminativeUpdater::Backprop(), NnetDiscriminativeUpdater::NnetDiscriminativeUpdater(), and NnetDiscriminativeUpdater::Propagate().
const DiscriminativeNnetExample & eg_ |
private |
Definition at line 85 of file nnet-compute-discriminative.cc.
Referenced by NnetDiscriminativeUpdater::GetDiscriminativePosteriors(), NnetDiscriminativeUpdater::GetInputFeatures(), NnetDiscriminativeUpdater::LatticeComputations(), NnetDiscriminativeUpdater::NnetDiscriminativeUpdater(), and NnetDiscriminativeUpdater::Propagate().
std::vector< CuMatrix< BaseFloat > > forward_data_ |
private |
Definition at line 93 of file nnet-compute-discriminative.cc.
Referenced by NnetDiscriminativeUpdater::Backprop(), NnetDiscriminativeUpdater::GetOutput(), NnetDiscriminativeUpdater::LatticeComputations(), and NnetDiscriminativeUpdater::Propagate().
Lattice lat_ |
private |
Definition at line 94 of file nnet-compute-discriminative.cc.
Referenced by NnetDiscriminativeUpdater::GetDiscriminativePosteriors(), and NnetDiscriminativeUpdater::LatticeComputations().
Nnet * nnet_to_update_ |
private |
Definition at line 86 of file nnet-compute-discriminative.cc.
Referenced by NnetDiscriminativeUpdater::Backprop(), NnetDiscriminativeUpdater::Propagate(), and NnetDiscriminativeUpdater::Update().
const NnetDiscriminativeUpdateOptions & opts_ |
private |
Definition at line 84 of file nnet-compute-discriminative.cc.
Referenced by NnetDiscriminativeUpdater::GetDiscriminativePosteriors(), NnetDiscriminativeUpdater::LatticeComputations(), and NnetDiscriminativeUpdater::NnetDiscriminativeUpdater().
std::vector< int32 > silence_phones_ |
private |
Definition at line 96 of file nnet-compute-discriminative.cc.
Referenced by NnetDiscriminativeUpdater::GetDiscriminativePosteriors(), NnetDiscriminativeUpdater::LatticeComputations(), and NnetDiscriminativeUpdater::NnetDiscriminativeUpdater().
NnetDiscriminativeStats * stats_ |
private |
Definition at line 89 of file nnet-compute-discriminative.cc.
Referenced by NnetDiscriminativeUpdater::LatticeComputations().
const TransitionModel & tmodel_ |
private |
Definition at line 83 of file nnet-compute-discriminative.cc.
Referenced by NnetDiscriminativeUpdater::GetDiscriminativePosteriors(), and NnetDiscriminativeUpdater::LatticeComputations().