This class is for single-threaded discriminative training of neural nets.
#include <nnet-discriminative-training.h>
Public Member Functions
  NnetDiscriminativeTrainer (const NnetDiscriminativeOptions &config, const TransitionModel &tmodel, const VectorBase< BaseFloat > &priors, Nnet *nnet)
  void Train (const NnetDiscriminativeExample &eg)
  bool PrintTotalStats () const
  ~NnetDiscriminativeTrainer ()
Private Member Functions
  void ProcessOutputs (const NnetDiscriminativeExample &eg, NnetComputer *computer)
Private Attributes
  const NnetDiscriminativeOptions opts_
  const TransitionModel & tmodel_
  CuVector< BaseFloat > log_priors_
  Nnet * nnet_
  Nnet * delta_nnet_
  CachingOptimizingCompiler compiler_
  int32 num_minibatches_processed_
  unordered_map< std::string, DiscriminativeObjectiveFunctionInfo, StringHasher > objf_info_
This class is for single-threaded discriminative training of neural nets.
Definition at line 87 of file nnet-discriminative-training.h.
NnetDiscriminativeTrainer (const NnetDiscriminativeOptions &config, const TransitionModel &tmodel, const VectorBase< BaseFloat > &priors, Nnet *nnet)
Definition at line 27 of file nnet-discriminative-training.cc.
References NnetDiscriminativeTrainer::compiler_, Nnet::Copy(), NnetDiscriminativeTrainer::delta_nnet_, KALDI_ASSERT, KALDI_LOG, KALDI_WARN, NnetDiscriminativeTrainer::log_priors_, NnetTrainerOptions::max_param_change, NnetTrainerOptions::momentum, NnetDiscriminativeTrainer::nnet_, NnetDiscriminativeOptions::nnet_config, Input::Open(), NnetTrainerOptions::read_cache, CachingOptimizingCompiler::ReadCache(), kaldi::nnet3::ScaleNnet(), Input::Stream(), NnetTrainerOptions::zero_component_stats, and kaldi::nnet3::ZeroComponentStats().
~NnetDiscriminativeTrainer ( )
Definition at line 264 of file nnet-discriminative-training.cc.
References NnetTrainerOptions::binary_write_cache, NnetDiscriminativeTrainer::compiler_, NnetDiscriminativeTrainer::delta_nnet_, NnetDiscriminativeOptions::nnet_config, NnetDiscriminativeTrainer::opts_, Output::Stream(), NnetTrainerOptions::write_cache, and CachingOptimizingCompiler::WriteCache().
bool PrintTotalStats ( ) const
Definition at line 191 of file nnet-discriminative-training.cc.
References DiscriminativeOptions::criterion, NnetDiscriminativeOptions::discriminative_config, NnetDiscriminativeTrainer::objf_info_, NnetDiscriminativeTrainer::opts_, and DiscriminativeObjectiveFunctionInfo::PrintTotalStats().
Referenced by main().
void ProcessOutputs (const NnetDiscriminativeExample &eg, NnetComputer *computer) [private]
Definition at line 109 of file nnet-discriminative-training.cc.
References NnetComputer::AcceptInput(), NnetDiscriminativeOptions::apply_deriv_weights, kaldi::discriminative::ComputeDiscriminativeObjfAndDeriv(), DiscriminativeOptions::criterion, NnetDiscriminativeSupervision::deriv_weights, NnetDiscriminativeOptions::discriminative_config, Nnet::GetNodeIndex(), NnetComputer::GetOutput(), Nnet::IsOutputNode(), KALDI_ERR, kaldi::kTrans, kaldi::kUndefined, NnetDiscriminativeTrainer::log_priors_, CuMatrixBase< Real >::MulRowsVec(), NnetDiscriminativeSupervision::name, NnetDiscriminativeTrainer::nnet_, NnetDiscriminativeOptions::nnet_config, NnetDiscriminativeTrainer::num_minibatches_processed_, NnetDiscriminativeTrainer::objf_info_, NnetDiscriminativeTrainer::opts_, NnetDiscriminativeExample::outputs, NnetTrainerOptions::print_interval, CuMatrix< Real >::Resize(), CuMatrixBase< Real >::Scale(), NnetDiscriminativeSupervision::supervision, NnetDiscriminativeTrainer::tmodel_, DiscriminativeObjectiveInfo::tot_objf, DiscriminativeObjectiveInfo::tot_t_weighted, kaldi::TraceMatMat(), and DiscriminativeOptions::xent_regularize.
Referenced by NnetDiscriminativeTrainer::Train().
void Train (const NnetDiscriminativeExample &eg)
Definition at line 63 of file nnet-discriminative-training.cc.
References NnetComputer::AcceptInputs(), kaldi::nnet3::AddNnet(), CachingOptimizingCompiler::Compile(), NnetDiscriminativeTrainer::compiler_, NnetTrainerOptions::compute_config, NnetDiscriminativeTrainer::delta_nnet_, NnetDiscriminativeOptions::discriminative_config, kaldi::nnet3::DotProduct(), kaldi::nnet3::GetDiscriminativeComputationRequest(), NnetDiscriminativeExample::inputs, KALDI_LOG, KALDI_WARN, NnetTrainerOptions::max_param_change, NnetTrainerOptions::momentum, NnetDiscriminativeTrainer::nnet_, NnetDiscriminativeOptions::nnet_config, NnetDiscriminativeTrainer::opts_, NnetDiscriminativeTrainer::ProcessOutputs(), NnetComputer::Run(), kaldi::nnet3::ScaleNnet(), NnetTrainerOptions::store_component_stats, and DiscriminativeOptions::xent_regularize.
Referenced by main().
CachingOptimizingCompiler compiler_ [private]
Definition at line 116 of file nnet-discriminative-training.h.
Referenced by NnetDiscriminativeTrainer::NnetDiscriminativeTrainer(), NnetDiscriminativeTrainer::Train(), and NnetDiscriminativeTrainer::~NnetDiscriminativeTrainer().
Nnet* delta_nnet_ [private]
Definition at line 112 of file nnet-discriminative-training.h.
Referenced by NnetDiscriminativeTrainer::NnetDiscriminativeTrainer(), NnetDiscriminativeTrainer::Train(), and NnetDiscriminativeTrainer::~NnetDiscriminativeTrainer().
CuVector<BaseFloat> log_priors_ [private]
Definition at line 108 of file nnet-discriminative-training.h.
Referenced by NnetDiscriminativeTrainer::NnetDiscriminativeTrainer(), and NnetDiscriminativeTrainer::ProcessOutputs().
Nnet* nnet_ [private]
Definition at line 110 of file nnet-discriminative-training.h.
Referenced by NnetDiscriminativeTrainer::NnetDiscriminativeTrainer(), NnetDiscriminativeTrainer::ProcessOutputs(), and NnetDiscriminativeTrainer::Train().
int32 num_minibatches_processed_ [private]
Definition at line 118 of file nnet-discriminative-training.h.
Referenced by NnetDiscriminativeTrainer::ProcessOutputs().
unordered_map<std::string, DiscriminativeObjectiveFunctionInfo, StringHasher> objf_info_ [private]
Definition at line 123 of file nnet-discriminative-training.h.
Referenced by NnetDiscriminativeTrainer::PrintTotalStats(), and NnetDiscriminativeTrainer::ProcessOutputs().
const NnetDiscriminativeOptions opts_ [private]
Definition at line 105 of file nnet-discriminative-training.h.
Referenced by NnetDiscriminativeTrainer::PrintTotalStats(), NnetDiscriminativeTrainer::ProcessOutputs(), NnetDiscriminativeTrainer::Train(), and NnetDiscriminativeTrainer::~NnetDiscriminativeTrainer().
const TransitionModel& tmodel_ [private]
Definition at line 107 of file nnet-discriminative-training.h.
Referenced by NnetDiscriminativeTrainer::ProcessOutputs().