This class computes objective-function values for nnet3 discriminative training, for diagnostic purposes.
#include <nnet-discriminative-diagnostics.h>
Public Member Functions
NnetDiscriminativeComputeObjf (const NnetComputeProbOptions &nnet_config, const discriminative::DiscriminativeOptions &discriminative_config, const TransitionModel &tmodel, const VectorBase< BaseFloat > &priors, const Nnet &nnet)
void Reset ()
void Compute (const NnetDiscriminativeExample &eg)
bool PrintTotalStats () const
const discriminative::DiscriminativeObjectiveInfo * GetObjective (const std::string &output_name) const
const Nnet & GetDeriv () const
~NnetDiscriminativeComputeObjf ()
Private Member Functions
void ProcessOutputs (const NnetDiscriminativeExample &eg, NnetComputer *computer)
Private Attributes
NnetComputeProbOptions nnet_config_
discriminative::DiscriminativeOptions discriminative_config_
const TransitionModel & tmodel_
CuVector< BaseFloat > log_priors_
const Nnet & nnet_
CachingOptimizingCompiler compiler_
Nnet * deriv_nnet_
int32 num_minibatches_processed_
unordered_map< std::string, discriminative::DiscriminativeObjectiveInfo, StringHasher > objf_info_
This class computes objective-function values for nnet3 discriminative training, for diagnostic purposes.
It also supports computing model derivatives.
Definition at line 39 of file nnet-discriminative-diagnostics.h.
NnetDiscriminativeComputeObjf (const NnetComputeProbOptions &nnet_config,
                               const discriminative::DiscriminativeOptions &discriminative_config,
                               const TransitionModel &tmodel,
                               const VectorBase< BaseFloat > &priors,
                               const Nnet &nnet)
Definition at line 28 of file nnet-discriminative-diagnostics.cc.
References NnetComputeProbOptions::compute_deriv, NnetDiscriminativeComputeObjf::deriv_nnet_, NnetDiscriminativeComputeObjf::log_priors_, NnetDiscriminativeComputeObjf::nnet_, NnetDiscriminativeComputeObjf::nnet_config_, kaldi::nnet3::ScaleNnet(), and kaldi::nnet3::SetNnetAsGradient().
~NnetDiscriminativeComputeObjf ()
Definition at line 56 of file nnet-discriminative-diagnostics.cc.
References NnetDiscriminativeComputeObjf::deriv_nnet_.
void Compute (const NnetDiscriminativeExample &eg)
Definition at line 69 of file nnet-discriminative-diagnostics.cc.
References NnetComputer::AcceptInputs(), CachingOptimizingCompiler::Compile(), NnetDiscriminativeComputeObjf::compiler_, NnetComputeProbOptions::compute_config, NnetComputeProbOptions::compute_deriv, NnetDiscriminativeComputeObjf::deriv_nnet_, NnetDiscriminativeComputeObjf::discriminative_config_, kaldi::nnet3::GetDiscriminativeComputationRequest(), NnetDiscriminativeExample::inputs, NnetDiscriminativeComputeObjf::nnet_, NnetDiscriminativeComputeObjf::nnet_config_, NnetDiscriminativeComputeObjf::ProcessOutputs(), NnetComputer::Run(), and DiscriminativeOptions::xent_regularize.
const Nnet & GetDeriv () const
Definition at line 50 of file nnet-discriminative-diagnostics.cc.
References NnetDiscriminativeComputeObjf::deriv_nnet_, and KALDI_ERR.
const discriminative::DiscriminativeObjectiveInfo * GetObjective (const std::string &output_name) const
Definition at line 197 of file nnet-discriminative-diagnostics.cc.
References NnetDiscriminativeComputeObjf::objf_info_.
bool PrintTotalStats () const
Definition at line 157 of file nnet-discriminative-diagnostics.cc.
References DiscriminativeOptions::criterion, NnetDiscriminativeComputeObjf::discriminative_config_, Nnet::GetNodeIndex(), KALDI_ASSERT, KALDI_LOG, NnetDiscriminativeComputeObjf::nnet_, NnetDiscriminativeComputeObjf::objf_info_, DiscriminativeObjectiveInfo::PrintAll(), DiscriminativeObjectiveInfo::tot_l2_term, DiscriminativeObjectiveInfo::tot_t_weighted, and DiscriminativeObjectiveInfo::TotalObjf().
void ProcessOutputs (const NnetDiscriminativeExample &eg, NnetComputer *computer) [private]
Definition at line 92 of file nnet-discriminative-diagnostics.cc.
References NnetComputer::AcceptInput(), NnetComputeProbOptions::compute_deriv, kaldi::discriminative::ComputeDiscriminativeObjfAndDeriv(), NnetDiscriminativeComputeObjf::discriminative_config_, Nnet::GetNodeIndex(), NnetComputer::GetOutput(), Nnet::IsOutputNode(), KALDI_ERR, kaldi::kTrans, kaldi::kUndefined, NnetDiscriminativeComputeObjf::log_priors_, NnetDiscriminativeSupervision::name, NnetDiscriminativeComputeObjf::nnet_, NnetDiscriminativeComputeObjf::nnet_config_, NnetDiscriminativeComputeObjf::num_minibatches_processed_, NnetDiscriminativeComputeObjf::objf_info_, NnetDiscriminativeExample::outputs, CuMatrix< Real >::Resize(), NnetDiscriminativeSupervision::supervision, NnetDiscriminativeComputeObjf::tmodel_, DiscriminativeObjectiveInfo::tot_objf, DiscriminativeObjectiveInfo::tot_t_weighted, kaldi::TraceMatMat(), and DiscriminativeOptions::xent_regularize.
Referenced by NnetDiscriminativeComputeObjf::Compute().
void Reset ()
Definition at line 60 of file nnet-discriminative-diagnostics.cc.
References NnetDiscriminativeComputeObjf::deriv_nnet_, NnetDiscriminativeComputeObjf::num_minibatches_processed_, NnetDiscriminativeComputeObjf::objf_info_, kaldi::nnet3::ScaleNnet(), and kaldi::nnet3::SetNnetAsGradient().
CachingOptimizingCompiler compiler_ [private]
Definition at line 77 of file nnet-discriminative-diagnostics.h.
Referenced by NnetDiscriminativeComputeObjf::Compute().
Nnet * deriv_nnet_ [private]
Definition at line 78 of file nnet-discriminative-diagnostics.h.
Referenced by NnetDiscriminativeComputeObjf::Compute(), NnetDiscriminativeComputeObjf::GetDeriv(), NnetDiscriminativeComputeObjf::NnetDiscriminativeComputeObjf(), NnetDiscriminativeComputeObjf::Reset(), and NnetDiscriminativeComputeObjf::~NnetDiscriminativeComputeObjf().
discriminative::DiscriminativeOptions discriminative_config_ [private]
Definition at line 73 of file nnet-discriminative-diagnostics.h.
Referenced by NnetDiscriminativeComputeObjf::Compute(), NnetDiscriminativeComputeObjf::PrintTotalStats(), and NnetDiscriminativeComputeObjf::ProcessOutputs().
CuVector< BaseFloat > log_priors_ [private]
Definition at line 75 of file nnet-discriminative-diagnostics.h.
Referenced by NnetDiscriminativeComputeObjf::NnetDiscriminativeComputeObjf(), and NnetDiscriminativeComputeObjf::ProcessOutputs().
const Nnet & nnet_ [private]
NnetComputeProbOptions nnet_config_ [private]
Definition at line 71 of file nnet-discriminative-diagnostics.h.
Referenced by NnetDiscriminativeComputeObjf::Compute(), NnetDiscriminativeComputeObjf::NnetDiscriminativeComputeObjf(), and NnetDiscriminativeComputeObjf::ProcessOutputs().
int32 num_minibatches_processed_ [private]
Definition at line 79 of file nnet-discriminative-diagnostics.h.
Referenced by NnetDiscriminativeComputeObjf::ProcessOutputs(), and NnetDiscriminativeComputeObjf::Reset().
unordered_map< std::string, discriminative::DiscriminativeObjectiveInfo, StringHasher > objf_info_ [private]
Definition at line 81 of file nnet-discriminative-diagnostics.h.
Referenced by NnetDiscriminativeComputeObjf::GetObjective(), NnetDiscriminativeComputeObjf::PrintTotalStats(), NnetDiscriminativeComputeObjf::ProcessOutputs(), and NnetDiscriminativeComputeObjf::Reset().
const TransitionModel & tmodel_ [private]
Definition at line 74 of file nnet-discriminative-diagnostics.h.
Referenced by NnetDiscriminativeComputeObjf::ProcessOutputs().