This class is for computing cross-entropy and accuracy values in a neural network, for diagnostics.
#include <nnet-diagnostics.h>
Public Member Functions
  NnetComputeProb (const NnetComputeProbOptions &config, const Nnet &nnet)
  NnetComputeProb (const NnetComputeProbOptions &config, Nnet *nnet)
  void Reset ()
  void Compute (const NnetExample &eg)
  bool PrintTotalStats () const
  const SimpleObjectiveInfo * GetObjective (const std::string &output_name) const
  double GetTotalObjective (double *tot_weight) const
  const Nnet & GetDeriv () const
  ~NnetComputeProb ()
Private Member Functions
  void ProcessOutputs (const NnetExample &eg, NnetComputer *computer)
Private Attributes
  NnetComputeProbOptions config_
  const Nnet & nnet_
  bool deriv_nnet_owned_
  Nnet * deriv_nnet_
  CachingOptimizingCompiler compiler_
  int32 num_minibatches_processed_
  unordered_map<std::string, SimpleObjectiveInfo, StringHasher> objf_info_
  unordered_map<std::string, PerDimObjectiveInfo, StringHasher> accuracy_info_
This class is for computing cross-entropy and accuracy values in a neural network, for diagnostics.
Note: because we put a "logsoftmax" component in the nnet, the actual objective function becomes linear at the output, but the printed messages reflect the fact that it's the cross-entropy objective.
Definition at line 107 of file nnet-diagnostics.h.
NnetComputeProb (const NnetComputeProbOptions &config, const Nnet &nnet)
Definition at line 26 of file nnet-diagnostics.cc.
References NnetComputeProbOptions::compute_deriv, NnetComputeProb::config_, NnetComputeProb::deriv_nnet_, KALDI_ERR, NnetComputeProb::nnet_, kaldi::nnet3::ScaleNnet(), kaldi::nnet3::SetNnetAsGradient(), and NnetComputeProbOptions::store_component_stats.
NnetComputeProb (const NnetComputeProbOptions &config, Nnet *nnet)
Definition at line 45 of file nnet-diagnostics.cc.
References NnetComputeProbOptions::compute_deriv, KALDI_ASSERT, and NnetComputeProbOptions::store_component_stats.
~NnetComputeProb | ( | ) |
Definition at line 64 of file nnet-diagnostics.cc.
References NnetComputeProb::deriv_nnet_, and NnetComputeProb::deriv_nnet_owned_.
void Compute (const NnetExample &eg)
Definition at line 79 of file nnet-diagnostics.cc.
References NnetComputer::AcceptInputs(), CachingOptimizingCompiler::Compile(), NnetComputeProb::compiler_, NnetComputeProbOptions::compute_config, NnetComputeProbOptions::compute_deriv, NnetComputeProb::config_, NnetComputeProb::deriv_nnet_, kaldi::nnet3::GetComputationRequest(), NnetExample::io, NnetComputeProb::nnet_, NnetComputeProb::ProcessOutputs(), NnetComputer::Run(), and NnetComputeProbOptions::store_component_stats.
Referenced by kaldi::nnet3::ComputeObjf(), main(), and kaldi::nnet3::RecomputeStats().
const Nnet & GetDeriv () const
Definition at line 58 of file nnet-diagnostics.cc.
References NnetComputeProbOptions::compute_deriv, NnetComputeProb::config_, NnetComputeProb::deriv_nnet_, and KALDI_ERR.
Referenced by main().
const SimpleObjectiveInfo * GetObjective (const std::string &output_name) const
Definition at line 299 of file nnet-diagnostics.cc.
References NnetComputeProb::objf_info_.
Referenced by main().
double GetTotalObjective (double *tot_weight) const
Definition at line 309 of file nnet-diagnostics.cc.
References NnetComputeProb::objf_info_.
Referenced by kaldi::nnet3::ComputeObjf().
bool PrintTotalStats () const
Definition at line 149 of file nnet-diagnostics.cc.
References NnetComputeProb::accuracy_info_, Nnet::GetNode(), Nnet::GetNodeIndex(), rnnlm::j, KALDI_ASSERT, KALDI_LOG, kaldi::nnet3::kLinear, NnetComputeProb::nnet_, NetworkNode::objective_type, NnetComputeProb::objf_info_, SimpleObjectiveInfo::tot_objective, PerDimObjectiveInfo::tot_objective_vec, SimpleObjectiveInfo::tot_weight, PerDimObjectiveInfo::tot_weight_vec, and NetworkNode::u.
Referenced by main(), and kaldi::nnet3::RecomputeStats().
void ProcessOutputs (const NnetExample &eg, NnetComputer *computer)   [private]
Definition at line 97 of file nnet-diagnostics.cc.
References NnetComputeProb::accuracy_info_, NnetComputeProbOptions::compute_accuracy, NnetComputeProbOptions::compute_deriv, NnetComputeProbOptions::compute_per_dim_accuracy, kaldi::nnet3::ComputeAccuracy(), kaldi::nnet3::ComputeObjectiveFunction(), NnetComputeProb::config_, NnetIo::features, Nnet::GetNode(), Nnet::GetNodeIndex(), NnetComputer::GetOutput(), NnetExample::io, Nnet::IsOutputNode(), KALDI_ERR, NnetIo::name, NnetComputeProb::nnet_, NnetComputeProb::num_minibatches_processed_, CuMatrixBase< Real >::NumCols(), GeneralMatrix::NumCols(), NetworkNode::objective_type, NnetComputeProb::objf_info_, SimpleObjectiveInfo::tot_objective, PerDimObjectiveInfo::tot_objective_vec, SimpleObjectiveInfo::tot_weight, PerDimObjectiveInfo::tot_weight_vec, and NetworkNode::u.
Referenced by NnetComputeProb::Compute().
void Reset ()
Definition at line 69 of file nnet-diagnostics.cc.
References NnetComputeProb::accuracy_info_, NnetComputeProb::deriv_nnet_, NnetComputeProb::num_minibatches_processed_, NnetComputeProb::objf_info_, kaldi::nnet3::ScaleNnet(), and kaldi::nnet3::SetNnetAsGradient().
Referenced by kaldi::nnet3::ComputeObjf().
unordered_map<std::string, PerDimObjectiveInfo, StringHasher> accuracy_info_   [private]
Definition at line 161 of file nnet-diagnostics.h.
Referenced by NnetComputeProb::PrintTotalStats(), NnetComputeProb::ProcessOutputs(), and NnetComputeProb::Reset().
CachingOptimizingCompiler compiler_   [private]
Definition at line 154 of file nnet-diagnostics.h.
Referenced by NnetComputeProb::Compute().
NnetComputeProbOptions config_   [private]
Definition at line 149 of file nnet-diagnostics.h.
Referenced by NnetComputeProb::Compute(), NnetComputeProb::GetDeriv(), NnetComputeProb::NnetComputeProb(), and NnetComputeProb::ProcessOutputs().
Nnet * deriv_nnet_   [private]
Definition at line 153 of file nnet-diagnostics.h.
Referenced by NnetComputeProb::Compute(), NnetComputeProb::GetDeriv(), NnetComputeProb::NnetComputeProb(), NnetComputeProb::Reset(), and NnetComputeProb::~NnetComputeProb().
bool deriv_nnet_owned_   [private]
Definition at line 152 of file nnet-diagnostics.h.
Referenced by NnetComputeProb::~NnetComputeProb().
const Nnet & nnet_   [private]
Definition at line 150 of file nnet-diagnostics.h.
Referenced by NnetComputeProb::Compute(), NnetComputeProb::NnetComputeProb(), NnetComputeProb::PrintTotalStats(), and NnetComputeProb::ProcessOutputs().
int32 num_minibatches_processed_   [private]
Definition at line 157 of file nnet-diagnostics.h.
Referenced by NnetComputeProb::ProcessOutputs(), and NnetComputeProb::Reset().
unordered_map<std::string, SimpleObjectiveInfo, StringHasher> objf_info_   [private]
Definition at line 159 of file nnet-diagnostics.h.
Referenced by NnetComputeProb::GetObjective(), NnetComputeProb::GetTotalObjective(), NnetComputeProb::PrintTotalStats(), NnetComputeProb::ProcessOutputs(), and NnetComputeProb::Reset().