This class is for single-threaded training of neural nets using the 'chain' model.
#include <nnet-chain-training.h>

Public Member Functions:
- NnetChainTrainer(const NnetChainTrainingOptions &config, const fst::StdVectorFst &den_fst, Nnet *nnet)
- void Train(const NnetChainExample &eg)
- bool PrintTotalStats() const
- ~NnetChainTrainer()
Private Member Functions:
- void TrainInternal(const NnetChainExample &eg, const NnetComputation &computation)
- void TrainInternalBackstitch(const NnetChainExample &eg, const NnetComputation &computation, bool is_backstitch_step1)
- void ProcessOutputs(bool is_backstitch_step2, const NnetChainExample &eg, NnetComputer *computer)
Private Attributes:
- const NnetChainTrainingOptions opts_
- chain::DenominatorGraph den_graph_
- Nnet *nnet_
- Nnet *delta_nnet_
- CachingOptimizingCompiler compiler_
- int32 num_minibatches_processed_
- MaxChangeStats max_change_stats_
- unordered_map<std::string, ObjectiveFunctionInfo, StringHasher> objf_info_
- int32 srand_seed_
This class is for single-threaded training of neural nets using the 'chain' model.
Definition at line 55 of file nnet-chain-training.h.
NnetChainTrainer(const NnetChainTrainingOptions &config, const fst::StdVectorFst &den_fst, Nnet *nnet)
Definition at line 27 of file nnet-chain-training.cc.
References NnetTrainerOptions::backstitch_training_interval, NnetChainTrainer::compiler_, Nnet::Copy(), NnetChainTrainer::delta_nnet_, KALDI_ASSERT, KALDI_LOG, KALDI_WARN, NnetTrainerOptions::max_param_change, NnetTrainerOptions::momentum, NnetChainTrainer::nnet_, NnetChainTrainingOptions::nnet_config, NnetTrainerOptions::read_cache, CachingOptimizingCompiler::ReadCache(), kaldi::nnet3::ScaleNnet(), Input::Stream(), NnetTrainerOptions::zero_component_stats, and kaldi::nnet3::ZeroComponentStats().
~NnetChainTrainer()
Definition at line 287 of file nnet-chain-training.cc.
References NnetTrainerOptions::binary_write_cache, NnetChainTrainer::compiler_, NnetChainTrainer::delta_nnet_, KALDI_LOG, NnetChainTrainingOptions::nnet_config, NnetChainTrainer::opts_, Output::Stream(), NnetTrainerOptions::write_cache, and CachingOptimizingCompiler::WriteCache().
bool PrintTotalStats() const
Definition at line 273 of file nnet-chain-training.cc.
References NnetChainTrainer::max_change_stats_, NnetChainTrainer::nnet_, NnetChainTrainer::objf_info_, MaxChangeStats::Print(), and ObjectiveFunctionInfo::PrintTotalStats().
      
void ProcessOutputs(bool is_backstitch_step2, const NnetChainExample &eg, NnetComputer *computer)  [private]
Definition at line 204 of file nnet-chain-training.cc.
References NnetComputer::AcceptInput(), NnetChainTrainingOptions::apply_deriv_weights, NnetChainTrainingOptions::chain_config, NnetChainTrainer::den_graph_, NnetChainSupervision::deriv_weights, Nnet::GetNodeIndex(), NnetComputer::GetOutput(), Nnet::IsOutputNode(), KALDI_ERR, kaldi::kTrans, kaldi::kUndefined, CuMatrixBase< Real >::MulRowsVec(), NnetChainSupervision::name, NnetChainTrainer::nnet_, NnetChainTrainingOptions::nnet_config, NnetChainTrainer::num_minibatches_processed_, NVTX_RANGE, NnetChainTrainer::objf_info_, NnetChainTrainer::opts_, NnetChainExample::outputs, NnetTrainerOptions::print_interval, CuMatrixBase< Real >::Scale(), NnetChainSupervision::supervision, and kaldi::TraceMatMat().
Referenced by NnetChainTrainer::TrainInternal(), and NnetChainTrainer::TrainInternalBackstitch().
void Train(const NnetChainExample &eg)
Definition at line 60 of file nnet-chain-training.cc.
References NnetTrainerOptions::backstitch_training_interval, NnetTrainerOptions::backstitch_training_scale, NnetChainTrainingOptions::chain_config, CachingOptimizingCompiler::Compile(), NnetChainTrainer::compiler_, kaldi::nnet3::ConsolidateMemory(), NnetChainTrainer::delta_nnet_, kaldi::nnet3::FreezeNaturalGradient(), kaldi::nnet3::GetChainComputationRequest(), KALDI_ASSERT, NnetTrainerOptions::momentum, NnetChainTrainer::nnet_, NnetChainTrainingOptions::nnet_config, NnetChainTrainer::num_minibatches_processed_, NVTX_RANGE, NnetChainTrainer::opts_, kaldi::nnet3::ResetGenerators(), NnetChainTrainer::srand_seed_, NnetTrainerOptions::store_component_stats, NnetChainTrainer::TrainInternal(), and NnetChainTrainer::TrainInternalBackstitch().
      
void TrainInternal(const NnetChainExample &eg, const NnetComputation &computation)  [private]
Definition at line 97 of file nnet-chain-training.cc.
References NnetComputer::AcceptInputs(), kaldi::nnet3::ApplyL2Regularization(), NnetTrainerOptions::batchnorm_stats_scale, NnetTrainerOptions::compute_config, kaldi::nnet3::ConstrainOrthonormal(), NnetChainTrainer::delta_nnet_, kaldi::nnet3::GetNumNvalues(), NnetChainExample::inputs, NnetTrainerOptions::l2_regularize_factor, NnetChainTrainer::max_change_stats_, NnetTrainerOptions::max_param_change, NnetTrainerOptions::momentum, NnetChainTrainer::nnet_, NnetChainTrainingOptions::nnet_config, NVTX_RANGE, NnetChainTrainer::opts_, NnetChainTrainer::ProcessOutputs(), NnetComputer::Run(), kaldi::nnet3::ScaleBatchnormStats(), kaldi::nnet3::ScaleNnet(), and kaldi::nnet3::UpdateNnetWithMaxChange().
Referenced by NnetChainTrainer::Train().
      
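void TrainInternalBackstitch(const NnetChainExample &eg, const NnetComputation &computation, bool is_backstitch_step1)  [private]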
Definition at line 143 of file nnet-chain-training.cc.
References NnetComputer::AcceptInputs(), kaldi::nnet3::ApplyL2Regularization(), NnetTrainerOptions::backstitch_training_scale, NnetTrainerOptions::batchnorm_stats_scale, NnetTrainerOptions::compute_config, kaldi::nnet3::ConstrainOrthonormal(), NnetChainTrainer::delta_nnet_, kaldi::nnet3::GetNumNvalues(), NnetChainExample::inputs, NnetTrainerOptions::l2_regularize_factor, NnetChainTrainer::max_change_stats_, NnetTrainerOptions::max_param_change, NnetChainTrainer::nnet_, NnetChainTrainingOptions::nnet_config, NnetChainTrainer::opts_, NnetChainTrainer::ProcessOutputs(), NnetComputer::Run(), kaldi::nnet3::ScaleBatchnormStats(), kaldi::nnet3::ScaleNnet(), and kaldi::nnet3::UpdateNnetWithMaxChange().
Referenced by NnetChainTrainer::Train().
      
CachingOptimizingCompiler compiler_  [private]
Definition at line 89 of file nnet-chain-training.h.
Referenced by NnetChainTrainer::NnetChainTrainer(), NnetChainTrainer::Train(), and NnetChainTrainer::~NnetChainTrainer().
      
Nnet* delta_nnet_  [private]
Definition at line 87 of file nnet-chain-training.h.
Referenced by NnetChainTrainer::NnetChainTrainer(), NnetChainTrainer::Train(), NnetChainTrainer::TrainInternal(), NnetChainTrainer::TrainInternalBackstitch(), and NnetChainTrainer::~NnetChainTrainer().
      
chain::DenominatorGraph den_graph_  [private]
Definition at line 85 of file nnet-chain-training.h.
Referenced by NnetChainTrainer::ProcessOutputs().
      
MaxChangeStats max_change_stats_  [private]
Definition at line 97 of file nnet-chain-training.h.
Referenced by NnetChainTrainer::PrintTotalStats(), NnetChainTrainer::TrainInternal(), and NnetChainTrainer::TrainInternalBackstitch().
      
Nnet* nnet_  [private]
Definition at line 86 of file nnet-chain-training.h.
Referenced by NnetChainTrainer::NnetChainTrainer(), NnetChainTrainer::PrintTotalStats(), NnetChainTrainer::ProcessOutputs(), NnetChainTrainer::Train(), NnetChainTrainer::TrainInternal(), and NnetChainTrainer::TrainInternalBackstitch().
      
int32 num_minibatches_processed_  [private]
Definition at line 94 of file nnet-chain-training.h.
Referenced by NnetChainTrainer::ProcessOutputs(), and NnetChainTrainer::Train().
      
unordered_map<std::string, ObjectiveFunctionInfo, StringHasher> objf_info_  [private]
Definition at line 99 of file nnet-chain-training.h.
Referenced by NnetChainTrainer::PrintTotalStats(), and NnetChainTrainer::ProcessOutputs().
      
const NnetChainTrainingOptions opts_  [private]
Definition at line 83 of file nnet-chain-training.h.
Referenced by NnetChainTrainer::ProcessOutputs(), NnetChainTrainer::Train(), NnetChainTrainer::TrainInternal(), NnetChainTrainer::TrainInternalBackstitch(), and NnetChainTrainer::~NnetChainTrainer().
      
int32 srand_seed_  [private]
Definition at line 104 of file nnet-chain-training.h.
Referenced by NnetChainTrainer::Train().