This class is for computing objective-function values in a nnet3+chain setup, for diagnostics. (See the detailed description below.)
#include <nnet-chain-diagnostics.h>
Public Member Functions | |
NnetChainComputeProb (const NnetComputeProbOptions &nnet_config, const chain::ChainTrainingOptions &chain_config, const fst::StdVectorFst &den_fst, const Nnet &nnet) | |
NnetChainComputeProb (const NnetComputeProbOptions &nnet_config, const chain::ChainTrainingOptions &chain_config, const fst::StdVectorFst &den_fst, Nnet *nnet) | |
void | Reset () |
void | Compute (const NnetChainExample &chain_eg) |
bool | PrintTotalStats () const |
const ChainObjectiveInfo * | GetObjective (const std::string &output_name) const |
double | GetTotalObjective (double *tot_weight) const |
const Nnet & | GetDeriv () const |
~NnetChainComputeProb () | |
Private Member Functions | |
void | ProcessOutputs (const NnetChainExample &chain_eg, NnetComputer *computer) |
Private Attributes | |
NnetComputeProbOptions | nnet_config_ |
chain::ChainTrainingOptions | chain_config_ |
chain::DenominatorGraph | den_graph_ |
const Nnet & | nnet_ |
CachingOptimizingCompiler | compiler_ |
bool | deriv_nnet_owned_ |
Nnet * | deriv_nnet_ |
int32 | num_minibatches_processed_ |
unordered_map< std::string, ChainObjectiveInfo, StringHasher > | objf_info_ |
This class is for computing objective-function values in a nnet3+chain setup, for diagnostics.
It also supports computing model derivatives. Note: if the --xent-regularization option is nonzero, the cross-entropy objective will be computed, and displayed when you call PrintTotalStats(), but it will not contribute to model derivatives (there is no code to compute the regularized objective function, and anyway it's not clear that we really need this regularization in the combination phase).
Definition at line 54 of file nnet-chain-diagnostics.h.
NnetChainComputeProb | ( | const NnetComputeProbOptions & | nnet_config, |
const chain::ChainTrainingOptions & | chain_config, | ||
const fst::StdVectorFst & | den_fst, | ||
const Nnet & | nnet | ||
) |
Definition at line 26 of file nnet-chain-diagnostics.cc.
References NnetComputeProbOptions::compute_deriv, NnetChainComputeProb::deriv_nnet_, KALDI_ERR, NnetChainComputeProb::nnet_, NnetChainComputeProb::nnet_config_, kaldi::nnet3::ScaleNnet(), kaldi::nnet3::SetNnetAsGradient(), and NnetComputeProbOptions::store_component_stats.
NnetChainComputeProb | ( | const NnetComputeProbOptions & | nnet_config, |
const chain::ChainTrainingOptions & | chain_config, | ||
const fst::StdVectorFst & | den_fst, | ||
Nnet * | nnet | ||
) |
Definition at line 50 of file nnet-chain-diagnostics.cc.
References NnetComputeProbOptions::compute_deriv, NnetChainComputeProb::den_graph_, KALDI_ASSERT, and NnetComputeProbOptions::store_component_stats.
~NnetChainComputeProb | ( | ) |
Definition at line 74 of file nnet-chain-diagnostics.cc.
References NnetChainComputeProb::deriv_nnet_, and NnetChainComputeProb::deriv_nnet_owned_.
void Compute | ( | const NnetChainExample & | chain_eg | ) |
Definition at line 88 of file nnet-chain-diagnostics.cc.
References NnetComputer::AcceptInputs(), NnetChainComputeProb::chain_config_, CachingOptimizingCompiler::Compile(), NnetChainComputeProb::compiler_, NnetComputeProbOptions::compute_config, NnetComputeProbOptions::compute_deriv, NnetChainComputeProb::deriv_nnet_, kaldi::nnet3::GetChainComputationRequest(), NnetChainExample::inputs, NnetChainComputeProb::nnet_, NnetChainComputeProb::nnet_config_, NnetChainComputeProb::ProcessOutputs(), NnetComputer::Run(), and NnetComputeProbOptions::store_component_stats.
Referenced by kaldi::nnet3::RecomputeStats().
const Nnet & GetDeriv | ( | ) | const |
Definition at line 68 of file nnet-chain-diagnostics.cc.
References NnetComputeProbOptions::compute_deriv, NnetChainComputeProb::deriv_nnet_, KALDI_ERR, and NnetChainComputeProb::nnet_config_.
const ChainObjectiveInfo * GetObjective | ( | const std::string & | output_name | ) | const |
Definition at line 211 of file nnet-chain-diagnostics.cc.
References NnetChainComputeProb::objf_info_.
double GetTotalObjective | ( | double * | tot_weight | ) | const |
Definition at line 221 of file nnet-chain-diagnostics.cc.
References NnetChainComputeProb::objf_info_.
bool PrintTotalStats | ( | ) | const |
Definition at line 179 of file nnet-chain-diagnostics.cc.
References Nnet::GetNodeIndex(), KALDI_ASSERT, KALDI_LOG, NnetChainComputeProb::nnet_, NnetChainComputeProb::objf_info_, ChainObjectiveInfo::tot_l2_term, ChainObjectiveInfo::tot_like, and ChainObjectiveInfo::tot_weight.
Referenced by kaldi::nnet3::RecomputeStats().
void ProcessOutputs ( const NnetChainExample & chain_eg, NnetComputer * computer ) |
private |
Definition at line 115 of file nnet-chain-diagnostics.cc.
References NnetComputer::AcceptInput(), NnetChainComputeProb::chain_config_, NnetComputeProbOptions::compute_deriv, NnetChainComputeProb::den_graph_, Nnet::GetNodeIndex(), NnetComputer::GetOutput(), Nnet::IsOutputNode(), KALDI_ERR, kaldi::kTrans, kaldi::kUndefined, NnetChainSupervision::name, NnetChainComputeProb::nnet_, NnetChainComputeProb::nnet_config_, NnetChainComputeProb::num_minibatches_processed_, NnetChainComputeProb::objf_info_, NnetChainExample::outputs, CuMatrix< Real >::Resize(), NnetChainSupervision::supervision, ChainObjectiveInfo::tot_l2_term, ChainObjectiveInfo::tot_like, ChainObjectiveInfo::tot_weight, and kaldi::TraceMatMat().
Referenced by NnetChainComputeProb::Compute().
void Reset | ( | ) |
Definition at line 79 of file nnet-chain-diagnostics.cc.
References NnetChainComputeProb::deriv_nnet_, NnetChainComputeProb::num_minibatches_processed_, NnetChainComputeProb::objf_info_, kaldi::nnet3::ScaleNnet(), and kaldi::nnet3::SetNnetAsGradient().
chain::ChainTrainingOptions chain_config_ |
private |
Definition at line 101 of file nnet-chain-diagnostics.h.
Referenced by NnetChainComputeProb::Compute(), and NnetChainComputeProb::ProcessOutputs().
CachingOptimizingCompiler compiler_ |
private |
Definition at line 104 of file nnet-chain-diagnostics.h.
Referenced by NnetChainComputeProb::Compute().
chain::DenominatorGraph den_graph_ |
private |
Definition at line 102 of file nnet-chain-diagnostics.h.
Referenced by NnetChainComputeProb::NnetChainComputeProb(), and NnetChainComputeProb::ProcessOutputs().
Nnet* deriv_nnet_ |
private |
Definition at line 106 of file nnet-chain-diagnostics.h.
Referenced by NnetChainComputeProb::Compute(), NnetChainComputeProb::GetDeriv(), NnetChainComputeProb::NnetChainComputeProb(), NnetChainComputeProb::Reset(), and NnetChainComputeProb::~NnetChainComputeProb().
bool deriv_nnet_owned_ |
private |
Definition at line 105 of file nnet-chain-diagnostics.h.
Referenced by NnetChainComputeProb::~NnetChainComputeProb().
const Nnet& nnet_ |
private |
Definition at line 103 of file nnet-chain-diagnostics.h.
Referenced by NnetChainComputeProb::Compute(), NnetChainComputeProb::NnetChainComputeProb(), NnetChainComputeProb::PrintTotalStats(), and NnetChainComputeProb::ProcessOutputs().
NnetComputeProbOptions nnet_config_ |
private |
Definition at line 100 of file nnet-chain-diagnostics.h.
Referenced by NnetChainComputeProb::Compute(), NnetChainComputeProb::GetDeriv(), NnetChainComputeProb::NnetChainComputeProb(), and NnetChainComputeProb::ProcessOutputs().
int32 num_minibatches_processed_ |
private |
Definition at line 107 of file nnet-chain-diagnostics.h.
Referenced by NnetChainComputeProb::ProcessOutputs(), and NnetChainComputeProb::Reset().
unordered_map<std::string, ChainObjectiveInfo, StringHasher> objf_info_ |
private |
Definition at line 109 of file nnet-chain-diagnostics.h.
Referenced by NnetChainComputeProb::GetObjective(), NnetChainComputeProb::GetTotalObjective(), NnetChainComputeProb::PrintTotalStats(), NnetChainComputeProb::ProcessOutputs(), and NnetChainComputeProb::Reset().