nnet-discriminative-training.cc
// nnet3/nnet-discriminative-training.cc

// Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey)
// Copyright 2014-2015 Vimal Manohar

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "nnet3/nnet-discriminative-training.h"
#include "nnet3/nnet-utils.h"

namespace kaldi {
namespace nnet3 {

NnetDiscriminativeTrainer::NnetDiscriminativeTrainer(
    const NnetDiscriminativeOptions &opts,
    const TransitionModel &tmodel,
    const VectorBase<BaseFloat> &priors,
    Nnet *nnet):
    opts_(opts), tmodel_(tmodel), log_priors_(priors),
    nnet_(nnet),
    compiler_(*nnet, opts_.nnet_config.optimize_config),
    num_minibatches_processed_(0) {
  if (opts.nnet_config.zero_component_stats)
    ZeroComponentStats(nnet);
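  // When momentum or a max-param-change limit is in effect, updates cannot be
  // applied directly to 'nnet_'; instead each minibatch's parameter change is
  // accumulated in 'delta_nnet_' and committed (possibly scaled) at the end
  // of Train().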
  if (opts.nnet_config.momentum == 0.0 &&
      opts.nnet_config.max_param_change == 0.0) {
    delta_nnet_ = NULL;
  } else {
    KALDI_ASSERT(opts.nnet_config.momentum >= 0.0 &&
                 opts.nnet_config.max_param_change >= 0.0);
    delta_nnet_ = nnet_->Copy();
    ScaleNnet(0.0, delta_nnet_);
  }
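  // Optionally read a cache of previously compiled computations, which saves
  // recompilation when later minibatches have shapes the compiler has already
  // seen.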
  if (opts.nnet_config.read_cache != "") {
    bool binary;
    Input ki;
    if (ki.Open(opts.nnet_config.read_cache, &binary)) {
      compiler_.ReadCache(ki.Stream(), binary);
      KALDI_LOG << "Read computation cache from "
                << opts.nnet_config.read_cache;
    } else {
      KALDI_WARN << "Could not open cached computation. "
                    "Probably this is the first training iteration.";
    }
  }
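  // We were given priors; convert them to log-priors, which is what
  // ComputeDiscriminativeObjfAndDeriv() consumes.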
  log_priors_.ApplyLog();
}


void NnetDiscriminativeTrainer::Train(const NnetDiscriminativeExample &eg) {
  bool need_model_derivative = true;
  const NnetTrainerOptions &nnet_config = opts_.nnet_config;
  bool use_xent_regularization =
      (opts_.discriminative_config.xent_regularize != 0.0);
  ComputationRequest request;
  GetDiscriminativeComputationRequest(*nnet_, eg, need_model_derivative,
                                      nnet_config.store_component_stats,
                                      use_xent_regularization,
                                      need_model_derivative,
                                      &request);
  std::shared_ptr<const NnetComputation> computation = compiler_.Compile(request);

  NnetComputer computer(nnet_config.compute_config, *computation,
                        *nnet_,
                        (delta_nnet_ == NULL ? nnet_ : delta_nnet_));
  // give the inputs to the computer object.
  computer.AcceptInputs(*nnet_, eg.inputs);
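  // The first Run() call executes the forward computation.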
  computer.Run();

  this->ProcessOutputs(eg, &computer);
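  // ProcessOutputs() supplied the derivatives at the output nodes, so this
  // second Run() call executes the backward computation.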
  computer.Run();

  if (delta_nnet_ != NULL) {
    BaseFloat scale = (1.0 - nnet_config.momentum);
    if (nnet_config.max_param_change != 0.0) {
      BaseFloat param_delta =
          std::sqrt(DotProduct(*delta_nnet_, *delta_nnet_)) * scale;
      if (param_delta > nnet_config.max_param_change) {
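        // 'param_delta - param_delta != 0.0' holds only when param_delta is
        // infinite (inf - inf is NaN); any finite value gives exactly 0.0.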
        if (param_delta - param_delta != 0.0) {
          KALDI_WARN << "Infinite parameter change, will not apply.";
          ScaleNnet(0.0, delta_nnet_);
        } else {
          scale *= nnet_config.max_param_change / param_delta;
          KALDI_LOG << "Parameter change too big: " << param_delta << " > "
                    << "--max-param-change=" << nnet_config.max_param_change
                    << ", scaling by "
                    << nnet_config.max_param_change / param_delta;
        }
      }
    }
    AddNnet(*delta_nnet_, scale, nnet_);
    ScaleNnet(nnet_config.momentum, delta_nnet_);
  }
}


void NnetDiscriminativeTrainer::ProcessOutputs(const NnetDiscriminativeExample &eg,
                                               NnetComputer *computer) {
  // normally the eg will have just one output named 'output', but
  // we don't assume this.
  std::vector<NnetDiscriminativeSupervision>::const_iterator
      iter = eg.outputs.begin(),
      end = eg.outputs.end();
  for (; iter != end; ++iter) {
    const NnetDiscriminativeSupervision &sup = *iter;
    int32 node_index = nnet_->GetNodeIndex(sup.name);
    if (node_index < 0 ||
        !nnet_->IsOutputNode(node_index))
      KALDI_ERR << "Network has no output named " << sup.name;

    const CuMatrixBase<BaseFloat> &nnet_output = computer->GetOutput(sup.name);

    CuMatrix<BaseFloat> nnet_output_deriv(nnet_output.NumRows(),
                                          nnet_output.NumCols(),
                                          kUndefined);
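    // (allocated with kUndefined: there is no need to zero it, since
    // ComputeDiscriminativeObjfAndDeriv() below writes to it rather than
    // adding to it.)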

    bool use_xent = (opts_.discriminative_config.xent_regularize != 0.0);
    std::string xent_name = sup.name + "-xent";  // typically "output-xent".
    CuMatrix<BaseFloat> xent_deriv;
    if (use_xent)
      xent_deriv.Resize(nnet_output.NumRows(), nnet_output.NumCols(),
                        kUndefined);

    discriminative::DiscriminativeObjectiveInfo stats(opts_.discriminative_config);

    if (objf_info_.count(sup.name) == 0) {
      objf_info_[sup.name].stats.Configure(opts_.discriminative_config);
      objf_info_[sup.name].stats.Reset();
    }
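
    // This call does forward-backward over the numerator and denominator
    // lattices and computes the objective together with its derivative
    // w.r.t. the nnet output (plus, optionally, the cross-entropy
    // derivative).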
    ComputeDiscriminativeObjfAndDeriv(opts_.discriminative_config,
                                      tmodel_, log_priors_,
                                      sup.supervision, nnet_output,
                                      &stats,
                                      &nnet_output_deriv,
                                      (use_xent ? &xent_deriv : NULL));

    if (use_xent) {
      // this block computes the cross-entropy objective.
      const CuMatrixBase<BaseFloat> &xent_output = computer->GetOutput(xent_name);
      // at this point, xent_deriv is posteriors derived from the numerator
      // computation.  note, xent_objf has a factor of '.supervision.weight'.
      BaseFloat xent_objf = TraceMatMat(xent_output, xent_deriv, kTrans);
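      // 'xent_objf != xent_objf' is the standard NaN test: NaN is the only
      // value that does not compare equal to itself.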
      if (xent_objf != xent_objf) {
        BaseFloat default_objf = -10;
        xent_objf = default_objf;
      }

      discriminative::DiscriminativeObjectiveInfo xent_stats;
      xent_stats.tot_t_weighted = stats.tot_t_weighted;
      xent_stats.tot_objf = xent_objf;

      objf_info_[xent_name].UpdateStats(xent_name, "xent",
                                        opts_.nnet_config.print_interval,
                                        num_minibatches_processed_, xent_stats);
    }

    if (opts_.apply_deriv_weights && sup.deriv_weights.Dim() != 0) {
      CuVector<BaseFloat> cu_deriv_weights(sup.deriv_weights);
      nnet_output_deriv.MulRowsVec(cu_deriv_weights);
      if (use_xent)
        xent_deriv.MulRowsVec(cu_deriv_weights);
    }

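    // Hand the derivative w.r.t. this output back to the computer object; the
    // second Run() call in Train() uses it for the backward pass.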
    computer->AcceptInput(sup.name, &nnet_output_deriv);

    objf_info_[sup.name].UpdateStats(sup.name,
                                     opts_.discriminative_config.criterion,
                                     opts_.nnet_config.print_interval,
                                     num_minibatches_processed_++,
                                     stats);

    if (use_xent) {
      xent_deriv.Scale(opts_.discriminative_config.xent_regularize);
      computer->AcceptInput(xent_name, &xent_deriv);
    }
  }
}


bool NnetDiscriminativeTrainer::PrintTotalStats() const {
  unordered_map<std::string, DiscriminativeObjectiveFunctionInfo,
                StringHasher>::const_iterator
      iter = objf_info_.begin(),
      end = objf_info_.end();
  bool ans = false;
  for (; iter != end; ++iter) {
    const std::string &name = iter->first;
    const DiscriminativeObjectiveFunctionInfo &info = iter->second;
    bool ret = info.PrintTotalStats(name, opts_.discriminative_config.criterion);
    ans = ans || ret;
  }
  return ans;
}


void DiscriminativeObjectiveFunctionInfo::UpdateStats(
    const std::string &output_name,
    const std::string &criterion,
    int32 minibatches_per_phase,
    int32 minibatch_counter,
    discriminative::DiscriminativeObjectiveInfo this_minibatch_stats) {
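  // A 'phase' is a run of 'minibatches_per_phase' consecutive minibatches;
  // crossing a phase boundary triggers printing and resetting the per-phase
  // stats.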
  int32 phase = minibatch_counter / minibatches_per_phase;
  if (phase != current_phase) {
    KALDI_ASSERT(phase == current_phase + 1);  // or doesn't really make sense.
    PrintStatsForThisPhase(output_name, criterion, minibatches_per_phase);
    current_phase = phase;
    stats_this_phase.Reset();
  }
  stats_this_phase.Add(this_minibatch_stats);
  stats.Add(this_minibatch_stats);
}

void DiscriminativeObjectiveFunctionInfo::PrintStatsForThisPhase(
    const std::string &output_name,
    const std::string &criterion,
    int32 minibatches_per_phase) const {
  int32 start_minibatch = current_phase * minibatches_per_phase,
      end_minibatch = start_minibatch + minibatches_per_phase - 1;

  BaseFloat objf = (stats_this_phase.TotalObjf(criterion) /
                    stats_this_phase.tot_t_weighted);
  KALDI_LOG << "Average objective function for '" << output_name
            << "' for minibatches " << start_minibatch
            << '-' << end_minibatch << " is " << objf
            << " over " << stats_this_phase.tot_t_weighted << " frames.";
}

bool DiscriminativeObjectiveFunctionInfo::PrintTotalStats(
    const std::string &name,
    const std::string &criterion) const {
  BaseFloat objf = stats.TotalObjf(criterion) / stats.tot_t_weighted;

  double avg_gradients = (stats.tot_num_count + stats.tot_den_count) /
                         stats.tot_t_weighted;
  KALDI_LOG << "Average num+den count of stats is " << avg_gradients
            << " per frame, over "
            << stats.tot_t_weighted << " frames.";
  if (stats.tot_l2_term != 0.0) {
    KALDI_LOG << "Average l2 norm of output per frame is "
              << (stats.tot_l2_term / stats.tot_t_weighted) << " over "
              << stats.tot_t_weighted << " frames.";
  }

  KALDI_LOG << "Overall average objective function for '" << name << "' is "
            << objf << " over " << stats.tot_t_weighted << " frames.";
  KALDI_LOG << "[this line is to be parsed by a script:] "
            << criterion << "-per-frame="
            << objf;
  return (stats.tot_t_weighted != 0.0);
}


NnetDiscriminativeTrainer::~NnetDiscriminativeTrainer() {
  delete delta_nnet_;

  if (opts_.nnet_config.write_cache != "") {
    Output ko(opts_.nnet_config.write_cache, opts_.nnet_config.binary_write_cache);
    compiler_.WriteCache(ko.Stream(), opts_.nnet_config.binary_write_cache);
  }
}


}  // namespace nnet3
}  // namespace kaldi
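Usage sketch (not part of the file above): this trainer is normally driven by
the nnet3-discriminative-train binary, whose inner loop looks roughly like the
following. SequentialNnetDiscriminativeExampleReader is the standard Kaldi
table reader over NnetDiscriminativeExample; option parsing and the model
read/write are omitted here.

    NnetDiscriminativeTrainer trainer(opts, tmodel, priors, &nnet);
    SequentialNnetDiscriminativeExampleReader example_reader(examples_rspecifier);
    for (; !example_reader.Done(); example_reader.Next())
      trainer.Train(example_reader.Value());
    bool ok = trainer.PrintTotalStats();  // true if any data was processed.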