nnet-chain-diagnostics.cc
Go to the documentation of this file.
1 // nnet3/nnet-chain-diagnostics.cc
2 
3 // Copyright 2015 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
#include "nnet3/nnet-chain-diagnostics.h"
#include "nnet3/nnet-utils.h"
22 
23 namespace kaldi {
24 namespace nnet3 {
25 
27  const NnetComputeProbOptions &nnet_config,
28  const chain::ChainTrainingOptions &chain_config,
29  const fst::StdVectorFst &den_fst,
30  const Nnet &nnet):
31  nnet_config_(nnet_config),
32  chain_config_(chain_config),
33  den_graph_(den_fst, nnet.OutputDim("output")),
34  nnet_(nnet),
35  compiler_(nnet, nnet_config_.optimize_config, nnet_config_.compiler_config),
36  deriv_nnet_owned_(true),
37  deriv_nnet_(NULL),
38  num_minibatches_processed_(0) {
40  deriv_nnet_ = new Nnet(nnet_);
41  ScaleNnet(0.0, deriv_nnet_);
42  SetNnetAsGradient(deriv_nnet_); // force simple update
44  KALDI_ERR << "If you set store_component_stats == true and "
45  << "compute_deriv == false, use the other constructor.";
46  }
47 }
48 
49 
51  const NnetComputeProbOptions &nnet_config,
52  const chain::ChainTrainingOptions &chain_config,
53  const fst::StdVectorFst &den_fst,
54  Nnet *nnet):
55  nnet_config_(nnet_config),
56  chain_config_(chain_config),
57  den_graph_(den_fst, nnet->OutputDim("output")),
58  nnet_(*nnet),
59  compiler_(*nnet, nnet_config_.optimize_config, nnet_config_.compiler_config),
60  deriv_nnet_owned_(false),
61  deriv_nnet_(nnet),
63  KALDI_ASSERT(den_graph_.NumPdfs() > 0);
64  KALDI_ASSERT(nnet_config.store_component_stats && !nnet_config.compute_deriv);
65 }
66 
67 
70  KALDI_ERR << "GetDeriv() called when no derivatives were requested.";
71  return *deriv_nnet_;
72 }
73 
76  delete deriv_nnet_; // delete does nothing if pointer is NULL.
77 }
78 
81  objf_info_.clear();
82  if (deriv_nnet_) {
83  ScaleNnet(0.0, deriv_nnet_);
85  }
86 }
87 
89  bool need_model_derivative = nnet_config_.compute_deriv,
90  store_component_stats = nnet_config_.store_component_stats;
91  ComputationRequest request;
92  // if the options specify cross-entropy regularization, we'll be computing
93  // this objective (not interpolated with the regular objective-- we give it a
94  // separate name), but currently we won't make it contribute to the
95  // derivative-- we just compute the derivative of the regular output.
96  // This is because in the place where we use the derivative (the
97  // model-combination code) we decided to keep it simple and just use the
98  // regular objective.
99  bool use_xent_regularization = (chain_config_.xent_regularize != 0.0),
100  use_xent_derivative = false;
101  GetChainComputationRequest(nnet_, chain_eg, need_model_derivative,
102  store_component_stats, use_xent_regularization,
103  use_xent_derivative, &request);
104  std::shared_ptr<const NnetComputation> computation = compiler_.Compile(request);
105  NnetComputer computer(nnet_config_.compute_config, *computation,
106  nnet_, deriv_nnet_);
107  // give the inputs to the computer object.
108  computer.AcceptInputs(nnet_, chain_eg.inputs);
109  computer.Run();
110  this->ProcessOutputs(chain_eg, &computer);
112  computer.Run();
113 }
114 
116  NnetComputer *computer) {
117  // There will normally be just one output here, named 'output',
118  // but the code is more general than this.
119  std::vector<NnetChainSupervision>::const_iterator iter = eg.outputs.begin(),
120  end = eg.outputs.end();
121  for (; iter != end; ++iter) {
122  const NnetChainSupervision &sup = *iter;
123  int32 node_index = nnet_.GetNodeIndex(sup.name);
124  if (node_index < 0 ||
125  !nnet_.IsOutputNode(node_index))
126  KALDI_ERR << "Network has no output named " << sup.name;
127 
128  const CuMatrixBase<BaseFloat> &nnet_output = computer->GetOutput(sup.name);
129  bool use_xent = (chain_config_.xent_regularize != 0.0);
130  std::string xent_name = sup.name + "-xent"; // typically "output-xent".
131  CuMatrix<BaseFloat> nnet_output_deriv, xent_deriv;
133  nnet_output_deriv.Resize(nnet_output.NumRows(), nnet_output.NumCols(),
134  kUndefined);
135  if (use_xent)
136  xent_deriv.Resize(nnet_output.NumRows(), nnet_output.NumCols(),
137  kUndefined);
138 
139  BaseFloat tot_like, tot_l2_term, tot_weight;
140 
141  ComputeChainObjfAndDeriv(chain_config_, den_graph_,
142  sup.supervision, nnet_output,
143  &tot_like, &tot_l2_term, &tot_weight,
144  (nnet_config_.compute_deriv ? &nnet_output_deriv :
145  NULL), (use_xent ? &xent_deriv : NULL));
146 
147  // note: in this context we don't want to apply 'sup.deriv_weights' because
148  // this code is used only in combination, where it's part of an L-BFGS
149  // optimization algorithm, and in that case if there is a mismatch between
150  // the computed objective function and the derivatives, it may cause errors
151  // in the optimization procedure such as early termination. (line search
152  // and conjugate gradient descent both rely on the derivatives being
153  // accurate, and don't fail gracefully if the derivatives are not accurate).
154 
155  ChainObjectiveInfo &totals = objf_info_[sup.name];
156  totals.tot_weight += tot_weight;
157  totals.tot_like += tot_like;
158  totals.tot_l2_term += tot_l2_term;
159 
161  computer->AcceptInput(sup.name, &nnet_output_deriv);
162 
163  if (use_xent) {
164  ChainObjectiveInfo &xent_totals = objf_info_[xent_name];
165  // this block computes the cross-entropy objective.
166  const CuMatrixBase<BaseFloat> &xent_output = computer->GetOutput(
167  xent_name);
168  // at this point, xent_deriv is posteriors derived from the numerator
169  // computation. note, xent_deriv has a factor of '.supervision.weight',
170  // but so does tot_weight.
171  BaseFloat xent_objf = TraceMatMat(xent_output, xent_deriv, kTrans);
172  xent_totals.tot_weight += tot_weight;
173  xent_totals.tot_like += xent_objf;
174  }
176  }
177 }
178 
180  bool ans = false;
181  unordered_map<std::string, ChainObjectiveInfo, StringHasher>::const_iterator
182  iter, end;
183  iter = objf_info_.begin();
184  end = objf_info_.end();
185  for (; iter != end; ++iter) {
186  const std::string &name = iter->first;
187  int32 node_index = nnet_.GetNodeIndex(name);
188  KALDI_ASSERT(node_index >= 0);
189  const ChainObjectiveInfo &info = iter->second;
190  BaseFloat like = (info.tot_like / info.tot_weight),
191  l2_term = (info.tot_l2_term / info.tot_weight),
192  tot_objf = like + l2_term;
193  if (info.tot_l2_term == 0.0) {
194  KALDI_LOG << "Overall log-probability for '"
195  << name << "' is "
196  << like << " per frame"
197  << ", over " << info.tot_weight << " frames.";
198  } else {
199  KALDI_LOG << "Overall log-probability for '"
200  << name << "' is "
201  << like << " + " << l2_term << " = " << tot_objf << " per frame"
202  << ", over " << info.tot_weight << " frames.";
203  }
204  if (info.tot_weight > 0)
205  ans = true;
206  }
207  return ans;
208 }
209 
210 
212  const std::string &output_name) const {
213  unordered_map<std::string, ChainObjectiveInfo, StringHasher>::const_iterator
214  iter = objf_info_.find(output_name);
215  if (iter != objf_info_.end())
216  return &(iter->second);
217  else
218  return NULL;
219 }
220 
221 double NnetChainComputeProb::GetTotalObjective(double *total_weight) const {
222  double tot_objectives = 0.0;
223  double tot_weight = 0.0;
224  unordered_map<std::string, ChainObjectiveInfo, StringHasher>::const_iterator
225  iter = objf_info_.begin(), end = objf_info_.end();
226  for (; iter != end; ++iter) {
227  tot_objectives += iter->second.tot_like + iter->second.tot_l2_term;
228  tot_weight += iter->second.tot_weight;
229  }
230 
231  if (total_weight) *total_weight = tot_weight;
232  return tot_objectives;
233 }
234 
235 static bool HasXentOutputs(const Nnet &nnet) {
236  const std::vector<std::string> node_names = nnet.GetNodeNames();
237  for (std::vector<std::string>::const_iterator it = node_names.begin();
238  it != node_names.end(); ++it) {
239  int32 node_index = nnet.GetNodeIndex(*it);
240  if (nnet.IsOutputNode(node_index) &&
241  it->find("-xent") != std::string::npos) {
242  return true;
243  }
244  }
245  return false;
246 }
247 
248 void RecomputeStats(const std::vector<NnetChainExample> &egs,
249  const chain::ChainTrainingOptions &chain_config_in,
250  const fst::StdVectorFst &den_fst,
251  Nnet *nnet) {
252  KALDI_LOG << "Recomputing stats on nnet (affects batch-norm)";
253  chain::ChainTrainingOptions chain_config(chain_config_in);
254  if (HasXentOutputs(*nnet) &&
255  chain_config.xent_regularize == 0) {
256  // this forces it to compute the output for xent outputs,
257  // usually 'output-xent', which
258  // means that we'll be computing batch-norm stats for any
259  // components in that branch that have batch-norm.
260  chain_config.xent_regularize = 0.1;
261  }
262 
263  ZeroComponentStats(nnet);
264  NnetComputeProbOptions nnet_config;
265  nnet_config.store_component_stats = true;
266  NnetChainComputeProb prob_computer(nnet_config, chain_config, den_fst, nnet);
267  for (size_t i = 0; i < egs.size(); i++)
268  prob_computer.Compute(egs[i]);
269  prob_computer.PrintTotalStats();
270  KALDI_LOG << "Done recomputing stats.";
271 }
272 
273 
274 
275 } // namespace nnet3
276 } // namespace kaldi
double GetTotalObjective(double *tot_weight) const
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
chain::ChainTrainingOptions chain_config_
void ScaleNnet(BaseFloat scale, Nnet *nnet)
Scales the nnet parameters and stats by this scale.
Definition: nnet-utils.cc:312
This class is for computing objective-function values in a nnet3+chain setup, for diagnostics...
chain::Supervision supervision
The supervision object, containing the FST.
std::vector< NnetIo > inputs
'inputs' contains the input to the network — normally it has just one element called "input"...
kaldi::int32 int32
unordered_map< std::string, ChainObjectiveInfo, StringHasher > objf_info_
This class represents a matrix that&#39;s stored on the GPU if we have one, and in memory if not...
Definition: matrix-common.h:71
fst::StdVectorFst StdVectorFst
static bool HasXentOutputs(const Nnet &nnet)
std::string name
the name of the output in the neural net; in simple setups it will just be "output".
This file contains some miscellaneous functions dealing with class Nnet.
void SetNnetAsGradient(Nnet *nnet)
Sets nnet as gradient by Setting is_gradient_ to true and learning_rate_ to 1 for each UpdatableCompo...
Definition: nnet-utils.cc:292
void AcceptInput(const std::string &node_name, CuMatrix< BaseFloat > *input)
e.g.
NnetChainComputeProb(const NnetComputeProbOptions &nnet_config, const chain::ChainTrainingOptions &chain_config, const fst::StdVectorFst &den_fst, const Nnet &nnet)
const CuMatrixBase< BaseFloat > & GetOutput(const std::string &node_name)
bool IsOutputNode(int32 node) const
Returns true if this is an output node, meaning that it is of type kDescriptor and is not directly fo...
Definition: nnet-nnet.cc:112
std::vector< NnetChainSupervision > outputs
'outputs' contains the chain output supervision.
void RecomputeStats(const std::vector< NnetChainExample > &egs, const chain::ChainTrainingOptions &chain_config_in, const fst::StdVectorFst &den_fst, Nnet *nnet)
This function zeros the stored component-level stats in the nnet using ZeroComponentStats(), then recomputes them with the supplied egs.
void AcceptInputs(const Nnet &nnet, const std::vector< NnetIo > &io)
This convenience function calls AcceptInput() in turn on all the inputs in the training example...
NnetChainExample is like NnetExample, but specialized for lattice-free (chain) training.
#define KALDI_ERR
Definition: kaldi-error.h:147
void ZeroComponentStats(Nnet *nnet)
Zeroes the component stats in all nonlinear components in the nnet.
Definition: nnet-utils.cc:269
Real TraceMatMat(const MatrixBase< Real > &A, const MatrixBase< Real > &B, MatrixTransposeType trans)
We need to declare this here as it will be a friend function.
std::shared_ptr< const NnetComputation > Compile(const ComputationRequest &request)
Does the compilation and returns a const pointer to the result, which is owned by this class...
Matrix for CUDA computing.
Definition: matrix-common.h:69
class NnetComputer is responsible for executing the computation described in the "computation" object...
Definition: nnet-compute.h:59
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void Compute(const NnetChainExample &chain_eg)
int32 GetNodeIndex(const std::string &node_name) const
returns index associated with this node name, or -1 if no such index.
Definition: nnet-nnet.cc:466
void GetChainComputationRequest(const Nnet &nnet, const NnetChainExample &eg, bool need_model_derivative, bool store_component_stats, bool use_xent_regularization, bool use_xent_derivative, ComputationRequest *request)
This function takes a NnetChainExample and produces a ComputationRequest.
#define KALDI_LOG
Definition: kaldi-error.h:153
void Resize(MatrixIndexT rows, MatrixIndexT cols, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Allocate the memory.
Definition: cu-matrix.cc:50
const ChainObjectiveInfo * GetObjective(const std::string &output_name) const
const std::vector< std::string > & GetNodeNames() const
returns vector of node names (needed by some parsing code, for instance).
Definition: nnet-nnet.cc:63
void ProcessOutputs(const NnetChainExample &chain_eg, NnetComputer *computer)
void Run()
This does either the forward or backward computation, depending when it is called (in a typical compu...