This class is for single-threaded training of neural nets using standard objective functions such as cross-entropy (implemented with logsoftmax nonlinearity and a linear objective function) and quadratic loss. More...
#include <nnet-training.h>
Public Member Functions | |
NnetTrainer (const NnetTrainerOptions &config, Nnet *nnet) | |
void | Train (const NnetExample &eg) |
bool | PrintTotalStats () const |
~NnetTrainer () | |
Private Member Functions | |
void | TrainInternal (const NnetExample &eg, const NnetComputation &computation) |
void | TrainInternalBackstitch (const NnetExample &eg, const NnetComputation &computation, bool is_backstitch_step1) |
void | ProcessOutputs (bool is_backstitch_step2, const NnetExample &eg, NnetComputer *computer) |
Private Attributes | |
const NnetTrainerOptions | config_ |
Nnet * | nnet_ |
Nnet * | delta_nnet_ |
CachingOptimizingCompiler | compiler_ |
int32 | num_minibatches_processed_ |
MaxChangeStats | max_change_stats_ |
unordered_map< std::string, ObjectiveFunctionInfo, StringHasher > | objf_info_ |
int32 | srand_seed_ |
This class is for single-threaded training of neural nets using standard objective functions such as cross-entropy (implemented with logsoftmax nonlinearity and a linear objective function) and quadratic loss.
A possible future improvement is to use two separate threads: one for compilation and one for computation. This would improve efficiency only when the structure of the input examples differs from one example to the next, which is not what we expect to see in speech-recognition training. (If the structure is the same each time, the CachingOptimizingCompiler notices this and reuses the computation it compiled last time.)
Definition at line 180 of file nnet-training.h.
NnetTrainer(const NnetTrainerOptions &config, Nnet *nnet)
Definition at line 27 of file nnet-training.cc.
References NnetTrainerOptions::backstitch_training_interval, NnetTrainer::compiler_, NnetTrainer::config_, Nnet::Copy(), NnetTrainer::delta_nnet_, KALDI_ASSERT, KALDI_LOG, KALDI_WARN, NnetTrainerOptions::max_param_change, NnetTrainerOptions::momentum, NnetTrainer::nnet_, Input::Open(), NnetTrainerOptions::read_cache, CachingOptimizingCompiler::ReadCache(), kaldi::nnet3::ScaleNnet(), Input::Stream(), NnetTrainerOptions::zero_component_stats, and kaldi::nnet3::ZeroComponentStats().
~NnetTrainer()
Definition at line 330 of file nnet-training.cc.
References NnetTrainerOptions::binary_write_cache, NnetTrainer::compiler_, NnetTrainer::config_, NnetTrainer::delta_nnet_, KALDI_LOG, Output::Stream(), NnetTrainerOptions::write_cache, and CachingOptimizingCompiler::WriteCache().
bool PrintTotalStats() const
Definition at line 220 of file nnet-training.cc.
References rnnlm::i, NnetTrainer::max_change_stats_, NnetTrainer::nnet_, NnetTrainer::objf_info_, MaxChangeStats::Print(), and ObjectiveFunctionInfo::PrintTotalStats().
Referenced by main().
void ProcessOutputs(bool is_backstitch_step2, const NnetExample &eg, NnetComputer *computer) [private]
Definition at line 191 of file nnet-training.cc.
References kaldi::nnet3::ComputeObjectiveFunction(), NnetTrainer::config_, NnetIo::features, Nnet::GetNode(), Nnet::GetNodeIndex(), NnetExample::io, Nnet::IsOutputNode(), KALDI_ASSERT, NnetIo::name, NnetTrainer::nnet_, NnetTrainer::num_minibatches_processed_, NetworkNode::objective_type, NnetTrainer::objf_info_, NnetTrainerOptions::print_interval, and NetworkNode::u.
Referenced by NnetTrainer::TrainInternal(), and NnetTrainer::TrainInternalBackstitch().
void Train(const NnetExample &eg)
Definition at line 57 of file nnet-training.cc.
References NnetTrainerOptions::backstitch_training_interval, NnetTrainerOptions::backstitch_training_scale, CachingOptimizingCompiler::Compile(), NnetTrainer::compiler_, NnetTrainer::config_, kaldi::nnet3::ConsolidateMemory(), NnetTrainer::delta_nnet_, kaldi::nnet3::FreezeNaturalGradient(), kaldi::nnet3::GetComputationRequest(), KALDI_ASSERT, NnetTrainerOptions::momentum, NnetTrainer::nnet_, NnetTrainer::num_minibatches_processed_, kaldi::nnet3::ResetGenerators(), NnetTrainer::srand_seed_, NnetTrainerOptions::store_component_stats, NnetTrainer::TrainInternal(), and NnetTrainer::TrainInternalBackstitch().
Referenced by main().
void TrainInternal(const NnetExample &eg, const NnetComputation &computation) [private]
Definition at line 91 of file nnet-training.cc.
References NnetComputer::AcceptInputs(), kaldi::nnet3::ApplyL2Regularization(), NnetTrainerOptions::batchnorm_stats_scale, NnetTrainerOptions::compute_config, NnetTrainer::config_, kaldi::nnet3::ConstrainOrthonormal(), NnetTrainer::delta_nnet_, kaldi::nnet3::GetNumNvalues(), NnetExample::io, NnetTrainerOptions::l2_regularize_factor, NnetTrainer::max_change_stats_, NnetTrainerOptions::max_param_change, NnetTrainerOptions::momentum, NnetTrainer::nnet_, NnetTrainer::ProcessOutputs(), NnetComputer::Run(), kaldi::nnet3::ScaleBatchnormStats(), kaldi::nnet3::ScaleNnet(), and kaldi::nnet3::UpdateNnetWithMaxChange().
Referenced by NnetTrainer::Train().
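void TrainInternalBackstitch(const NnetExample &eg, const NnetComputation &computation, bool is_backstitch_step1) [private]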
Definition at line 131 of file nnet-training.cc.
References NnetComputer::AcceptInputs(), kaldi::nnet3::ApplyL2Regularization(), NnetTrainerOptions::backstitch_training_scale, NnetTrainerOptions::batchnorm_stats_scale, NnetTrainerOptions::compute_config, NnetTrainer::config_, kaldi::nnet3::ConstrainOrthonormal(), NnetTrainer::delta_nnet_, kaldi::nnet3::GetNumNvalues(), NnetExample::io, NnetTrainerOptions::l2_regularize_factor, NnetTrainer::max_change_stats_, NnetTrainerOptions::max_param_change, NnetTrainer::nnet_, NnetTrainer::ProcessOutputs(), NnetComputer::Run(), kaldi::nnet3::ScaleBatchnormStats(), kaldi::nnet3::ScaleNnet(), and kaldi::nnet3::UpdateNnetWithMaxChange().
Referenced by NnetTrainer::Train().
CachingOptimizingCompiler compiler_ [private]
Definition at line 212 of file nnet-training.h.
Referenced by NnetTrainer::NnetTrainer(), NnetTrainer::Train(), and NnetTrainer::~NnetTrainer().
const NnetTrainerOptions config_ [private]
Definition at line 207 of file nnet-training.h.
Referenced by NnetTrainer::NnetTrainer(), NnetTrainer::ProcessOutputs(), NnetTrainer::Train(), NnetTrainer::TrainInternal(), NnetTrainer::TrainInternalBackstitch(), and NnetTrainer::~NnetTrainer().
Nnet * delta_nnet_ [private]
Definition at line 209 of file nnet-training.h.
Referenced by NnetTrainer::NnetTrainer(), NnetTrainer::Train(), NnetTrainer::TrainInternal(), NnetTrainer::TrainInternalBackstitch(), and NnetTrainer::~NnetTrainer().
MaxChangeStats max_change_stats_ [private]
Definition at line 220 of file nnet-training.h.
Referenced by NnetTrainer::PrintTotalStats(), NnetTrainer::TrainInternal(), and NnetTrainer::TrainInternalBackstitch().
Nnet * nnet_ [private]
Definition at line 208 of file nnet-training.h.
Referenced by NnetTrainer::NnetTrainer(), NnetTrainer::PrintTotalStats(), NnetTrainer::ProcessOutputs(), NnetTrainer::Train(), NnetTrainer::TrainInternal(), and NnetTrainer::TrainInternalBackstitch().
int32 num_minibatches_processed_ [private]
Definition at line 217 of file nnet-training.h.
Referenced by NnetTrainer::ProcessOutputs(), and NnetTrainer::Train().
unordered_map< std::string, ObjectiveFunctionInfo, StringHasher > objf_info_ [private]
Definition at line 222 of file nnet-training.h.
Referenced by NnetTrainer::PrintTotalStats(), and NnetTrainer::ProcessOutputs().
int32 srand_seed_ [private]
Definition at line 227 of file nnet-training.h.
Referenced by NnetTrainer::Train().