This class is for single-threaded training of neural nets using the 'chain' model. More...
#include <nnet-chain-training.h>
Public Member Functions | |
NnetChainTrainer (const NnetChainTrainingOptions &config, const fst::StdVectorFst &den_fst, Nnet *nnet) | |
void | Train (const NnetChainExample &eg) |
bool | PrintTotalStats () const |
~NnetChainTrainer () | |
Private Member Functions | |
void | TrainInternal (const NnetChainExample &eg, const NnetComputation &computation) |
void | TrainInternalBackstitch (const NnetChainExample &eg, const NnetComputation &computation, bool is_backstitch_step1) |
void | ProcessOutputs (bool is_backstitch_step2, const NnetChainExample &eg, NnetComputer *computer) |
Private Attributes | |
const NnetChainTrainingOptions | opts_ |
chain::DenominatorGraph | den_graph_ |
Nnet * | nnet_ |
Nnet * | delta_nnet_ |
CachingOptimizingCompiler | compiler_ |
int32 | num_minibatches_processed_ |
MaxChangeStats | max_change_stats_ |
unordered_map< std::string, ObjectiveFunctionInfo, StringHasher > | objf_info_ |
int32 | srand_seed_ |
This class is for single-threaded training of neural nets using the 'chain' model.
Definition at line 55 of file nnet-chain-training.h.
NnetChainTrainer (const NnetChainTrainingOptions &config, const fst::StdVectorFst &den_fst, Nnet *nnet)
Definition at line 27 of file nnet-chain-training.cc.
References NnetTrainerOptions::backstitch_training_interval, NnetChainTrainer::compiler_, Nnet::Copy(), NnetChainTrainer::delta_nnet_, KALDI_ASSERT, KALDI_LOG, KALDI_WARN, NnetTrainerOptions::max_param_change, NnetTrainerOptions::momentum, NnetChainTrainer::nnet_, NnetChainTrainingOptions::nnet_config, NnetTrainerOptions::read_cache, CachingOptimizingCompiler::ReadCache(), kaldi::nnet3::ScaleNnet(), Input::Stream(), NnetTrainerOptions::zero_component_stats, and kaldi::nnet3::ZeroComponentStats().
~NnetChainTrainer ()
Definition at line 287 of file nnet-chain-training.cc.
References NnetTrainerOptions::binary_write_cache, NnetChainTrainer::compiler_, NnetChainTrainer::delta_nnet_, KALDI_LOG, NnetChainTrainingOptions::nnet_config, NnetChainTrainer::opts_, Output::Stream(), NnetTrainerOptions::write_cache, and CachingOptimizingCompiler::WriteCache().
bool PrintTotalStats () const
Definition at line 273 of file nnet-chain-training.cc.
References NnetChainTrainer::max_change_stats_, NnetChainTrainer::nnet_, NnetChainTrainer::objf_info_, MaxChangeStats::Print(), and ObjectiveFunctionInfo::PrintTotalStats().
void ProcessOutputs (bool is_backstitch_step2, const NnetChainExample &eg, NnetComputer *computer) [private]
Definition at line 204 of file nnet-chain-training.cc.
References NnetComputer::AcceptInput(), NnetChainTrainingOptions::apply_deriv_weights, NnetChainTrainingOptions::chain_config, NnetChainTrainer::den_graph_, NnetChainSupervision::deriv_weights, Nnet::GetNodeIndex(), NnetComputer::GetOutput(), Nnet::IsOutputNode(), KALDI_ERR, kaldi::kTrans, kaldi::kUndefined, CuMatrixBase< Real >::MulRowsVec(), NnetChainSupervision::name, NnetChainTrainer::nnet_, NnetChainTrainingOptions::nnet_config, NnetChainTrainer::num_minibatches_processed_, NVTX_RANGE, NnetChainTrainer::objf_info_, NnetChainTrainer::opts_, NnetChainExample::outputs, NnetTrainerOptions::print_interval, CuMatrixBase< Real >::Scale(), NnetChainSupervision::supervision, and kaldi::TraceMatMat().
Referenced by NnetChainTrainer::TrainInternal(), and NnetChainTrainer::TrainInternalBackstitch().
void Train (const NnetChainExample &eg)
Definition at line 60 of file nnet-chain-training.cc.
References NnetTrainerOptions::backstitch_training_interval, NnetTrainerOptions::backstitch_training_scale, NnetChainTrainingOptions::chain_config, CachingOptimizingCompiler::Compile(), NnetChainTrainer::compiler_, kaldi::nnet3::ConsolidateMemory(), NnetChainTrainer::delta_nnet_, kaldi::nnet3::FreezeNaturalGradient(), kaldi::nnet3::GetChainComputationRequest(), KALDI_ASSERT, NnetTrainerOptions::momentum, NnetChainTrainer::nnet_, NnetChainTrainingOptions::nnet_config, NnetChainTrainer::num_minibatches_processed_, NVTX_RANGE, NnetChainTrainer::opts_, kaldi::nnet3::ResetGenerators(), NnetChainTrainer::srand_seed_, NnetTrainerOptions::store_component_stats, NnetChainTrainer::TrainInternal(), and NnetChainTrainer::TrainInternalBackstitch().
void TrainInternal (const NnetChainExample &eg, const NnetComputation &computation) [private]
Definition at line 97 of file nnet-chain-training.cc.
References NnetComputer::AcceptInputs(), kaldi::nnet3::ApplyL2Regularization(), NnetTrainerOptions::batchnorm_stats_scale, NnetTrainerOptions::compute_config, kaldi::nnet3::ConstrainOrthonormal(), NnetChainTrainer::delta_nnet_, kaldi::nnet3::GetNumNvalues(), NnetChainExample::inputs, NnetTrainerOptions::l2_regularize_factor, NnetChainTrainer::max_change_stats_, NnetTrainerOptions::max_param_change, NnetTrainerOptions::momentum, NnetChainTrainer::nnet_, NnetChainTrainingOptions::nnet_config, NVTX_RANGE, NnetChainTrainer::opts_, NnetChainTrainer::ProcessOutputs(), NnetComputer::Run(), kaldi::nnet3::ScaleBatchnormStats(), kaldi::nnet3::ScaleNnet(), and kaldi::nnet3::UpdateNnetWithMaxChange().
Referenced by NnetChainTrainer::Train().
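void TrainInternalBackstitch (const NnetChainExample &eg, const NnetComputation &computation, bool is_backstitch_step1) [private]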
Definition at line 143 of file nnet-chain-training.cc.
References NnetComputer::AcceptInputs(), kaldi::nnet3::ApplyL2Regularization(), NnetTrainerOptions::backstitch_training_scale, NnetTrainerOptions::batchnorm_stats_scale, NnetTrainerOptions::compute_config, kaldi::nnet3::ConstrainOrthonormal(), NnetChainTrainer::delta_nnet_, kaldi::nnet3::GetNumNvalues(), NnetChainExample::inputs, NnetTrainerOptions::l2_regularize_factor, NnetChainTrainer::max_change_stats_, NnetTrainerOptions::max_param_change, NnetChainTrainer::nnet_, NnetChainTrainingOptions::nnet_config, NnetChainTrainer::opts_, NnetChainTrainer::ProcessOutputs(), NnetComputer::Run(), kaldi::nnet3::ScaleBatchnormStats(), kaldi::nnet3::ScaleNnet(), and kaldi::nnet3::UpdateNnetWithMaxChange().
Referenced by NnetChainTrainer::Train().
CachingOptimizingCompiler compiler_ [private]
Definition at line 89 of file nnet-chain-training.h.
Referenced by NnetChainTrainer::NnetChainTrainer(), NnetChainTrainer::Train(), and NnetChainTrainer::~NnetChainTrainer().
Nnet * delta_nnet_ [private]
Definition at line 87 of file nnet-chain-training.h.
Referenced by NnetChainTrainer::NnetChainTrainer(), NnetChainTrainer::Train(), NnetChainTrainer::TrainInternal(), NnetChainTrainer::TrainInternalBackstitch(), and NnetChainTrainer::~NnetChainTrainer().
chain::DenominatorGraph den_graph_ [private]
Definition at line 85 of file nnet-chain-training.h.
Referenced by NnetChainTrainer::ProcessOutputs().
MaxChangeStats max_change_stats_ [private]
Definition at line 97 of file nnet-chain-training.h.
Referenced by NnetChainTrainer::PrintTotalStats(), NnetChainTrainer::TrainInternal(), and NnetChainTrainer::TrainInternalBackstitch().
Nnet * nnet_ [private]
Definition at line 86 of file nnet-chain-training.h.
Referenced by NnetChainTrainer::NnetChainTrainer(), NnetChainTrainer::PrintTotalStats(), NnetChainTrainer::ProcessOutputs(), NnetChainTrainer::Train(), NnetChainTrainer::TrainInternal(), and NnetChainTrainer::TrainInternalBackstitch().
int32 num_minibatches_processed_ [private]
Definition at line 94 of file nnet-chain-training.h.
Referenced by NnetChainTrainer::ProcessOutputs(), and NnetChainTrainer::Train().
unordered_map< std::string, ObjectiveFunctionInfo, StringHasher > objf_info_ [private]
Definition at line 99 of file nnet-chain-training.h.
Referenced by NnetChainTrainer::PrintTotalStats(), and NnetChainTrainer::ProcessOutputs().
const NnetChainTrainingOptions opts_ [private]
Definition at line 83 of file nnet-chain-training.h.
Referenced by NnetChainTrainer::ProcessOutputs(), NnetChainTrainer::Train(), NnetChainTrainer::TrainInternal(), NnetChainTrainer::TrainInternalBackstitch(), and NnetChainTrainer::~NnetChainTrainer().
int32 srand_seed_ [private]
Definition at line 104 of file nnet-chain-training.h.
Referenced by NnetChainTrainer::Train().