NnetChainComputeProb Class Reference

This class is for computing objective-function values in a nnet3+chain setup, for diagnostics.

#include <nnet-chain-diagnostics.h>

Public Member Functions

 NnetChainComputeProb (const NnetComputeProbOptions &nnet_config, const chain::ChainTrainingOptions &chain_config, const fst::StdVectorFst &den_fst, const Nnet &nnet)
 
 NnetChainComputeProb (const NnetComputeProbOptions &nnet_config, const chain::ChainTrainingOptions &chain_config, const fst::StdVectorFst &den_fst, Nnet *nnet)
 
void Reset ()
 
void Compute (const NnetChainExample &chain_eg)
 
bool PrintTotalStats () const
 
const ChainObjectiveInfo * GetObjective (const std::string &output_name) const
 
double GetTotalObjective (double *tot_weight) const
 
const Nnet & GetDeriv () const
 
 ~NnetChainComputeProb ()
 

Private Member Functions

void ProcessOutputs (const NnetChainExample &chain_eg, NnetComputer *computer)
 

Private Attributes

NnetComputeProbOptions nnet_config_
 
chain::ChainTrainingOptions chain_config_
 
chain::DenominatorGraph den_graph_
 
const Nnet & nnet_
 
CachingOptimizingCompiler compiler_
 
bool deriv_nnet_owned_
 
Nnet * deriv_nnet_
 
int32 num_minibatches_processed_
 
unordered_map< std::string, ChainObjectiveInfo, StringHasher > objf_info_
 

Detailed Description

This class is for computing objective-function values in a nnet3+chain setup, for diagnostics.

It also supports computing model derivatives. Note: if the --xent-regularization option is nonzero, the cross-entropy objective is computed and displayed when you call PrintTotalStats(), but it does not contribute to the model derivatives (there is no code to compute the regularized objective function, and it is not clear that this regularization is needed in the combination phase).

Definition at line 54 of file nnet-chain-diagnostics.h.

Constructor & Destructor Documentation

◆ NnetChainComputeProb() [1/2]

NnetChainComputeProb ( const NnetComputeProbOptions &  nnet_config,
const chain::ChainTrainingOptions &  chain_config,
const fst::StdVectorFst &  den_fst,
const Nnet &  nnet 
)

Definition at line 26 of file nnet-chain-diagnostics.cc.

References NnetComputeProbOptions::compute_deriv, NnetChainComputeProb::deriv_nnet_, KALDI_ERR, NnetChainComputeProb::nnet_, NnetChainComputeProb::nnet_config_, kaldi::nnet3::ScaleNnet(), kaldi::nnet3::SetNnetAsGradient(), and NnetComputeProbOptions::store_component_stats.

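NnetChainComputeProb::NnetChainComputeProb(
    const NnetComputeProbOptions &nnet_config,
    const chain::ChainTrainingOptions &chain_config,
    const fst::StdVectorFst &den_fst,
    const Nnet &nnet):
    nnet_config_(nnet_config),
    chain_config_(chain_config),
    den_graph_(den_fst, nnet.OutputDim("output")),
    nnet_(nnet),
    compiler_(nnet, nnet_config_.optimize_config, nnet_config_.compiler_config),
    deriv_nnet_owned_(true),
    deriv_nnet_(NULL),
    num_minibatches_processed_(0) {
  if (nnet_config_.compute_deriv) {
    deriv_nnet_ = new Nnet(nnet_);
    ScaleNnet(0.0, deriv_nnet_);
    SetNnetAsGradient(deriv_nnet_);  // force simple update
  } else if (nnet_config_.store_component_stats) {
    KALDI_ERR << "If you set store_component_stats == true and "
              << "compute_deriv == false, use the other constructor.";
  }
}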

◆ NnetChainComputeProb() [2/2]

NnetChainComputeProb ( const NnetComputeProbOptions &  nnet_config,
const chain::ChainTrainingOptions &  chain_config,
const fst::StdVectorFst &  den_fst,
Nnet *  nnet 
)

Definition at line 50 of file nnet-chain-diagnostics.cc.

References NnetComputeProbOptions::compute_deriv, NnetChainComputeProb::den_graph_, KALDI_ASSERT, and NnetComputeProbOptions::store_component_stats.

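NnetChainComputeProb::NnetChainComputeProb(
    const NnetComputeProbOptions &nnet_config,
    const chain::ChainTrainingOptions &chain_config,
    const fst::StdVectorFst &den_fst,
    Nnet *nnet):
    nnet_config_(nnet_config),
    chain_config_(chain_config),
    den_graph_(den_fst, nnet->OutputDim("output")),
    nnet_(*nnet),
    compiler_(*nnet, nnet_config_.optimize_config, nnet_config_.compiler_config),
    deriv_nnet_owned_(false),
    deriv_nnet_(nnet),
    num_minibatches_processed_(0) {
  KALDI_ASSERT(den_graph_.NumPdfs() > 0);
  KALDI_ASSERT(nnet_config.store_component_stats && !nnet_config.compute_deriv);
}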
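Read together, the two constructor listings imply a rule for which constructor fits a given pair of options: the const-reference form rejects store_component_stats without compute_deriv, and the pointer form asserts exactly that combination. The sketch below only encodes that reading; the `Ctor` enum and `ConstructorFor` function are hypothetical names for illustration and are not part of Kaldi.

```cpp
#include <cassert>

// Hypothetical labels for the two constructors documented above.
enum class Ctor { kConstNnet, kMutableNnet };

// Returns the constructor whose checks accept the given option pair,
// as read from the two listings (an interpretation, not a Kaldi API).
Ctor ConstructorFor(bool compute_deriv, bool store_component_stats) {
  // Storing component stats without derivatives writes into the caller's
  // network, so only the `Nnet *` form accepts this combination.
  if (store_component_stats && !compute_deriv) return Ctor::kMutableNnet;
  // Every other combination passes the checks of the `const Nnet &` form.
  return Ctor::kConstNnet;
}
```

For example, `ConstructorFor(false, true)` yields `Ctor::kMutableNnet`, which matches the error message in the first constructor telling the caller to "use the other constructor".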

◆ ~NnetChainComputeProb()

Definition at line 74 of file nnet-chain-diagnostics.cc.

References NnetChainComputeProb::deriv_nnet_, and NnetChainComputeProb::deriv_nnet_owned_.

NnetChainComputeProb::~NnetChainComputeProb() {
  if (deriv_nnet_owned_)
    delete deriv_nnet_;  // delete does nothing if pointer is NULL.
}
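The constructors and destructor implement an ownership flag: the object either allocates its own derivative network and frees it, or borrows a caller-supplied one and leaves it alone. A minimal sketch of that pattern follows, assuming a hypothetical `Model` type in place of Nnet and a hypothetical `GradHolder` class in place of NnetChainComputeProb; the instance counter exists only to make the ownership behaviour observable.

```cpp
#include <cassert>

// Hypothetical stand-in for Nnet; counts live instances.
struct Model {
  static int live;
  Model() { ++live; }
  Model(const Model &) { ++live; }
  ~Model() { --live; }
};
int Model::live = 0;

// Hypothetical holder mirroring deriv_nnet_ and deriv_nnet_owned_.
class GradHolder {
 public:
  // Owning form: copies the model only when derivatives are requested.
  GradHolder(const Model &model, bool compute_deriv)
      : owned_(true), grad_(nullptr) {
    if (compute_deriv) grad_ = new Model(model);
  }
  // Borrowing form: uses the caller's object and never frees it.
  explicit GradHolder(Model *model) : owned_(false), grad_(model) {}
  ~GradHolder() {
    if (owned_) delete grad_;  // deleting a null pointer is a no-op
  }
  GradHolder(const GradHolder &) = delete;
  GradHolder &operator=(const GradHolder &) = delete;
  bool HasGrad() const { return grad_ != nullptr; }

 private:
  bool owned_;
  Model *grad_;
};
```

Disabling copies in the sketch avoids a double delete when the holder owns its pointer; whether the real class does the same is not shown on this page.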

Member Function Documentation

◆ Compute()

void Compute ( const NnetChainExample &  chain_eg)

Definition at line 88 of file nnet-chain-diagnostics.cc.

References NnetComputer::AcceptInputs(), NnetChainComputeProb::chain_config_, CachingOptimizingCompiler::Compile(), NnetChainComputeProb::compiler_, NnetComputeProbOptions::compute_config, NnetComputeProbOptions::compute_deriv, NnetChainComputeProb::deriv_nnet_, kaldi::nnet3::GetChainComputationRequest(), NnetChainExample::inputs, NnetChainComputeProb::nnet_, NnetChainComputeProb::nnet_config_, NnetChainComputeProb::ProcessOutputs(), NnetComputer::Run(), and NnetComputeProbOptions::store_component_stats.

Referenced by kaldi::nnet3::RecomputeStats().

void NnetChainComputeProb::Compute(const NnetChainExample &chain_eg) {
  bool need_model_derivative = nnet_config_.compute_deriv,
      store_component_stats = nnet_config_.store_component_stats;
  ComputationRequest request;
  // if the options specify cross-entropy regularization, we'll be computing
  // this objective (not interpolated with the regular objective-- we give it a
  // separate name), but currently we won't make it contribute to the
  // derivative-- we just compute the derivative of the regular output.
  // This is because in the place where we use the derivative (the
  // model-combination code) we decided to keep it simple and just use the
  // regular objective.
  bool use_xent_regularization = (chain_config_.xent_regularize != 0.0),
      use_xent_derivative = false;
  GetChainComputationRequest(nnet_, chain_eg, need_model_derivative,
                             store_component_stats, use_xent_regularization,
                             use_xent_derivative, &request);
  std::shared_ptr<const NnetComputation> computation = compiler_.Compile(request);
  NnetComputer computer(nnet_config_.compute_config, *computation,
                        nnet_, deriv_nnet_);
  // give the inputs to the computer object.
  computer.AcceptInputs(nnet_, chain_eg.inputs);
  computer.Run();
  this->ProcessOutputs(chain_eg, &computer);
  if (nnet_config_.compute_deriv)
    computer.Run();
}
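The listing for Compute() calls Run() once for the forward pass and, only when derivatives were requested, a second time after ProcessOutputs() has supplied the output derivatives. The following sketch shows just that ordering, assuming a hypothetical `FakeComputer` that records calls instead of doing any computation; it is not the NnetComputer interface.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for NnetComputer: it only logs the call order.
struct FakeComputer {
  std::vector<std::string> log;
  int runs = 0;
  void AcceptInputs() { log.push_back("inputs"); }
  // The first call stands for the forward pass, later calls for backprop.
  void Run() { log.push_back(runs++ == 0 ? "forward" : "backward"); }
  void AcceptOutputDeriv() { log.push_back("output-deriv"); }
};

// Mirrors the control flow of Compute(): inputs, forward pass, then an
// optional derivative hand-off and second Run() when compute_deriv is set.
std::vector<std::string> ComputeFlow(bool compute_deriv) {
  FakeComputer computer;
  computer.AcceptInputs();
  computer.Run();
  if (compute_deriv) computer.AcceptOutputDeriv();  // done in ProcessOutputs()
  if (compute_deriv) computer.Run();
  return computer.log;
}
```

With `compute_deriv` false the log stops after "forward", which corresponds to the plain diagnostic use of the class.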

◆ GetDeriv()

const Nnet & GetDeriv ( ) const

Definition at line 68 of file nnet-chain-diagnostics.cc.

References NnetComputeProbOptions::compute_deriv, NnetChainComputeProb::deriv_nnet_, KALDI_ERR, and NnetChainComputeProb::nnet_config_.

const Nnet &NnetChainComputeProb::GetDeriv() const {
  if (!nnet_config_.compute_deriv)
    KALDI_ERR << "GetDeriv() called when no derivatives were requested.";
  return *deriv_nnet_;
}

◆ GetObjective()

const ChainObjectiveInfo * GetObjective ( const std::string &  output_name) const

Definition at line 211 of file nnet-chain-diagnostics.cc.

References NnetChainComputeProb::objf_info_.

const ChainObjectiveInfo* NnetChainComputeProb::GetObjective(
    const std::string &output_name) const {
  unordered_map<std::string, ChainObjectiveInfo, StringHasher>::const_iterator
      iter = objf_info_.find(output_name);
  if (iter != objf_info_.end())
    return &(iter->second);
  else
    return NULL;
}
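Because GetObjective() returns NULL for an output name that has not been seen, callers need to test the pointer before reading any totals. A minimal sketch of the lookup and a caller-side check follows; `ObjectiveInfo`, `ObjfMap`, `FindObjective` and `HasFrames` are hypothetical names, and the struct only mirrors the three ChainObjectiveInfo fields referenced on this page.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Hypothetical mirror of the ChainObjectiveInfo fields used on this page.
struct ObjectiveInfo {
  double tot_weight;
  double tot_like;
  double tot_l2_term;
};
using ObjfMap = std::unordered_map<std::string, ObjectiveInfo>;

// Same lookup-or-null contract as GetObjective().
const ObjectiveInfo *FindObjective(const ObjfMap &objf_info,
                                   const std::string &output_name) {
  auto iter = objf_info.find(output_name);
  return iter != objf_info.end() ? &iter->second : nullptr;
}

// Caller-side pattern: check the pointer before dereferencing it.
bool HasFrames(const ObjfMap &objf_info, const std::string &output_name) {
  const ObjectiveInfo *info = FindObjective(objf_info, output_name);
  return info != nullptr && info->tot_weight > 0.0;
}
```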

◆ GetTotalObjective()

double GetTotalObjective ( double *  tot_weight) const

Definition at line 221 of file nnet-chain-diagnostics.cc.

References NnetChainComputeProb::objf_info_.

double NnetChainComputeProb::GetTotalObjective(double *total_weight) const {
  double tot_objectives = 0.0;
  double tot_weight = 0.0;
  unordered_map<std::string, ChainObjectiveInfo, StringHasher>::const_iterator
      iter = objf_info_.begin(), end = objf_info_.end();
  for (; iter != end; ++iter) {
    tot_objectives += iter->second.tot_like + iter->second.tot_l2_term;
    tot_weight += iter->second.tot_weight;
  }

  if (total_weight) *total_weight = tot_weight;
  return tot_objectives;
}
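GetTotalObjective() sums tot_like and tot_l2_term over every entry of objf_info_, and reports the summed weight through an optional pointer. Judging from the listings on this page, entries that ProcessOutputs() stores under a "-xent" name are part of the same map and so are included in both sums. The sketch below reproduces the accumulation with hypothetical `ObjectiveInfo`, `ObjfMap` and `TotalObjective` names.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Hypothetical mirror of the ChainObjectiveInfo fields used on this page.
struct ObjectiveInfo {
  double tot_weight;
  double tot_like;
  double tot_l2_term;
};
using ObjfMap = std::unordered_map<std::string, ObjectiveInfo>;

// Sums likelihood and l2 terms over all outputs; the weight pointer may be null.
double TotalObjective(const ObjfMap &objf_info, double *total_weight) {
  double tot_objectives = 0.0, tot_weight = 0.0;
  for (const auto &entry : objf_info) {
    tot_objectives += entry.second.tot_like + entry.second.tot_l2_term;
    tot_weight += entry.second.tot_weight;
  }
  if (total_weight) *total_weight = tot_weight;
  return tot_objectives;
}
```

A caller that wants a per-frame figure for the regular output alone would, under this reading, use GetObjective("output") rather than dividing the total by the total weight.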

◆ PrintTotalStats()

bool PrintTotalStats ( ) const

Definition at line 179 of file nnet-chain-diagnostics.cc.

References Nnet::GetNodeIndex(), KALDI_ASSERT, KALDI_LOG, NnetChainComputeProb::nnet_, NnetChainComputeProb::objf_info_, ChainObjectiveInfo::tot_l2_term, ChainObjectiveInfo::tot_like, and ChainObjectiveInfo::tot_weight.

Referenced by kaldi::nnet3::RecomputeStats().

bool NnetChainComputeProb::PrintTotalStats() const {
  bool ans = false;
  unordered_map<std::string, ChainObjectiveInfo, StringHasher>::const_iterator
      iter, end;
  iter = objf_info_.begin();
  end = objf_info_.end();
  for (; iter != end; ++iter) {
    const std::string &name = iter->first;
    int32 node_index = nnet_.GetNodeIndex(name);
    KALDI_ASSERT(node_index >= 0);
    const ChainObjectiveInfo &info = iter->second;
    BaseFloat like = (info.tot_like / info.tot_weight),
        l2_term = (info.tot_l2_term / info.tot_weight),
        tot_objf = like + l2_term;
    if (info.tot_l2_term == 0.0) {
      KALDI_LOG << "Overall log-probability for '"
                << name << "' is "
                << like << " per frame"
                << ", over " << info.tot_weight << " frames.";
    } else {
      KALDI_LOG << "Overall log-probability for '"
                << name << "' is "
                << like << " + " << l2_term << " = " << tot_objf << " per frame"
                << ", over " << info.tot_weight << " frames.";
    }
    if (info.tot_weight > 0)
      ans = true;
  }
  return ans;
}
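PrintTotalStats() reports per-frame values by dividing each accumulated total by tot_weight, and its return value is true only if some output has positive weight. The sketch below shows the same normalization with hypothetical `ObjectiveInfo`, `PerFrame` and `Normalize` names. Unlike the listing, it skips the division when the weight is zero; that guard is an addition made for the example, not behaviour of the class.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical mirror of the ChainObjectiveInfo fields used on this page.
struct ObjectiveInfo {
  double tot_weight;
  double tot_like;
  double tot_l2_term;
};

// Per-frame figures as logged by PrintTotalStats(), plus a validity flag.
struct PerFrame {
  double like;
  double l2_term;
  double total;
  bool valid;  // false when no frames were accumulated
};

PerFrame Normalize(const ObjectiveInfo &info) {
  PerFrame out = {0.0, 0.0, 0.0, false};
  if (info.tot_weight > 0.0) {  // guard added in this sketch only
    out.like = info.tot_like / info.tot_weight;
    out.l2_term = info.tot_l2_term / info.tot_weight;
    out.total = out.like + out.l2_term;
    out.valid = true;
  }
  return out;
}
```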

◆ ProcessOutputs()

void ProcessOutputs ( const NnetChainExample &  chain_eg,
NnetComputer *  computer 
)
private

Definition at line 115 of file nnet-chain-diagnostics.cc.

References NnetComputer::AcceptInput(), NnetChainComputeProb::chain_config_, NnetComputeProbOptions::compute_deriv, NnetChainComputeProb::den_graph_, Nnet::GetNodeIndex(), NnetComputer::GetOutput(), Nnet::IsOutputNode(), KALDI_ERR, kaldi::kTrans, kaldi::kUndefined, NnetChainSupervision::name, NnetChainComputeProb::nnet_, NnetChainComputeProb::nnet_config_, NnetChainComputeProb::num_minibatches_processed_, NnetChainComputeProb::objf_info_, NnetChainExample::outputs, CuMatrix< Real >::Resize(), NnetChainSupervision::supervision, ChainObjectiveInfo::tot_l2_term, ChainObjectiveInfo::tot_like, ChainObjectiveInfo::tot_weight, and kaldi::TraceMatMat().

Referenced by NnetChainComputeProb::Compute().

void NnetChainComputeProb::ProcessOutputs(const NnetChainExample &eg,
                                          NnetComputer *computer) {
  // There will normally be just one output here, named 'output',
  // but the code is more general than this.
  std::vector<NnetChainSupervision>::const_iterator iter = eg.outputs.begin(),
      end = eg.outputs.end();
  for (; iter != end; ++iter) {
    const NnetChainSupervision &sup = *iter;
    int32 node_index = nnet_.GetNodeIndex(sup.name);
    if (node_index < 0 ||
        !nnet_.IsOutputNode(node_index))
      KALDI_ERR << "Network has no output named " << sup.name;

    const CuMatrixBase<BaseFloat> &nnet_output = computer->GetOutput(sup.name);
    bool use_xent = (chain_config_.xent_regularize != 0.0);
    std::string xent_name = sup.name + "-xent";  // typically "output-xent".
    CuMatrix<BaseFloat> nnet_output_deriv, xent_deriv;
    if (nnet_config_.compute_deriv)
      nnet_output_deriv.Resize(nnet_output.NumRows(), nnet_output.NumCols(),
                               kUndefined);
    if (use_xent)
      xent_deriv.Resize(nnet_output.NumRows(), nnet_output.NumCols(),
                        kUndefined);

    BaseFloat tot_like, tot_l2_term, tot_weight;

    ComputeChainObjfAndDeriv(chain_config_, den_graph_,
                             sup.supervision, nnet_output,
                             &tot_like, &tot_l2_term, &tot_weight,
                             (nnet_config_.compute_deriv ? &nnet_output_deriv :
                              NULL), (use_xent ? &xent_deriv : NULL));

    // note: in this context we don't want to apply 'sup.deriv_weights' because
    // this code is used only in combination, where it's part of an L-BFGS
    // optimization algorithm, and in that case if there is a mismatch between
    // the computed objective function and the derivatives, it may cause errors
    // in the optimization procedure such as early termination.  (line search
    // and conjugate gradient descent both rely on the derivatives being
    // accurate, and don't fail gracefully if the derivatives are not accurate).

    ChainObjectiveInfo &totals = objf_info_[sup.name];
    totals.tot_weight += tot_weight;
    totals.tot_like += tot_like;
    totals.tot_l2_term += tot_l2_term;

    if (nnet_config_.compute_deriv)
      computer->AcceptInput(sup.name, &nnet_output_deriv);

    if (use_xent) {
      ChainObjectiveInfo &xent_totals = objf_info_[xent_name];
      // this block computes the cross-entropy objective.
      const CuMatrixBase<BaseFloat> &xent_output = computer->GetOutput(
          xent_name);
      // at this point, xent_deriv is posteriors derived from the numerator
      // computation.  note, xent_deriv has a factor of '.supervision.weight',
      // but so does tot_weight.
      BaseFloat xent_objf = TraceMatMat(xent_output, xent_deriv, kTrans);
      xent_totals.tot_weight += tot_weight;
      xent_totals.tot_like += xent_objf;
    }
    num_minibatches_processed_++;
  }
}
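The cross-entropy objective in ProcessOutputs() is TraceMatMat(xent_output, xent_deriv, kTrans), that is tr(A Bᵀ) for two matrices of the same shape, which equals the sum of their element-wise products. With xent_deriv holding numerator posteriors (as the source comment states), this is a posterior-weighted sum of the "-xent" output values. The sketch below computes the same quantity on plain nested vectors; `Matrix` and `TraceMatMatTrans` are hypothetical names and the code does not use Kaldi's matrix classes.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical dense matrix type: a vector of equally sized rows.
typedef std::vector<std::vector<double> > Matrix;

// Computes tr(A * B^T) for same-shaped A and B, which reduces to the
// sum over all (i, j) of A[i][j] * B[i][j].
double TraceMatMatTrans(const Matrix &a, const Matrix &b) {
  assert(a.size() == b.size());
  double sum = 0.0;
  for (std::size_t i = 0; i < a.size(); ++i) {
    assert(a[i].size() == b[i].size());
    for (std::size_t j = 0; j < a[i].size(); ++j)
      sum += a[i][j] * b[i][j];
  }
  return sum;
}
```

For A = [[1, 2], [3, 4]] and B = [[0.5, 0.5], [1, 0]], the result is 0.5 + 1 + 3 + 0 = 4.5.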

◆ Reset()

void Reset ( )

Definition at line 79 of file nnet-chain-diagnostics.cc.

References NnetChainComputeProb::deriv_nnet_, NnetChainComputeProb::num_minibatches_processed_, NnetChainComputeProb::objf_info_, kaldi::nnet3::ScaleNnet(), and kaldi::nnet3::SetNnetAsGradient().

void NnetChainComputeProb::Reset() {
  num_minibatches_processed_ = 0;
  objf_info_.clear();
  if (deriv_nnet_) {
    ScaleNnet(0.0, deriv_nnet_);
    SetNnetAsGradient(deriv_nnet_);
  }
}

Member Data Documentation

◆ chain_config_

chain::ChainTrainingOptions chain_config_
private

◆ compiler_

CachingOptimizingCompiler compiler_
private

Definition at line 104 of file nnet-chain-diagnostics.h.

Referenced by NnetChainComputeProb::Compute().

◆ den_graph_

chain::DenominatorGraph den_graph_
private

◆ deriv_nnet_

Nnet * deriv_nnet_
private

◆ deriv_nnet_owned_

bool deriv_nnet_owned_
private

◆ nnet_

const Nnet & nnet_
private

◆ nnet_config_

NnetComputeProbOptions nnet_config_
private

◆ num_minibatches_processed_

int32 num_minibatches_processed_
private

◆ objf_info_

unordered_map< std::string, ChainObjectiveInfo, StringHasher > objf_info_
private


The documentation for this class was generated from the following files: