nnet3-combine.cc
Go to the documentation of this file.
1 // nnet3bin/nnet3-combine.cc
2 
3 // Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey)
4 // 2017 Yiming Wang
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABILITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include "base/kaldi-common.h"
22 #include "util/common-utils.h"
23 #include "nnet3/nnet-utils.h"
24 #include "nnet3/nnet-compute.h"
25 #include "nnet3/nnet-diagnostics.h"
26 
27 
28 namespace kaldi {
29 namespace nnet3 {
30 
31 // Computes and returns the objective function for the examples in 'egs' given
32 // the model in 'nnet'. If either of batchnorm/dropout test modes is true, we
33 // make a copy of 'nnet', set test modes on that and evaluate its objective.
34 // Note: the object that prob_computer->nnet_ refers to should be 'nnet'.
35 double ComputeObjf(bool batchnorm_test_mode, bool dropout_test_mode,
36  const std::vector<NnetExample> &egs, const Nnet &nnet,
37  NnetComputeProb *prob_computer) {
38  if (batchnorm_test_mode || dropout_test_mode) {
39  Nnet nnet_copy(nnet);
40  if (batchnorm_test_mode)
41  SetBatchnormTestMode(true, &nnet_copy);
42  if (dropout_test_mode)
43  SetDropoutTestMode(true, &nnet_copy);
44  NnetComputeProbOptions compute_prob_opts;
45  NnetComputeProb prob_computer_test(compute_prob_opts, nnet_copy);
46  return ComputeObjf(false, false, egs, nnet_copy, &prob_computer_test);
47  } else {
48  prob_computer->Reset();
49  std::vector<NnetExample>::const_iterator iter = egs.begin(),
50  end = egs.end();
51  for (; iter != end; ++iter)
52  prob_computer->Compute(*iter);
53  double tot_weights,
54  tot_objf = prob_computer->GetTotalObjective(&tot_weights);
55  KALDI_ASSERT(tot_weights > 0.0);
56  // inf/nan tot_objf->return -inf objective.
57  if (!(tot_objf == tot_objf && tot_objf - tot_objf == 0))
58  return -std::numeric_limits<double>::infinity();
59  // we prefer to deal with normalized objective functions.
60  return tot_objf / tot_weights;
61  }
62 }
63 
64 // Updates moving average over num_models nnets, given the average over
65 // previous (num_models - 1) nnets, and the new nnet.
67  const Nnet &nnet, Nnet *moving_average_nnet) {
68  KALDI_ASSERT(NumParameters(nnet) == NumParameters(*moving_average_nnet));
69  ScaleNnet((num_models - 1.0) / num_models, moving_average_nnet);
70  AddNnet(nnet, 1.0 / num_models, moving_average_nnet);
71 }
72 
73 }
74 }
75 
76 int main(int argc, char *argv[]) {
77  try {
78  using namespace kaldi;
79  using namespace kaldi::nnet3;
80  typedef kaldi::int32 int32;
81  typedef kaldi::int64 int64;
82 
83  const char *usage =
84  "Using a subset of training or held-out examples, compute the average\n"
85  "over the first n nnet3 models where we maxize the objective function\n"
86  "for n. Note that the order of models has been reversed before\n"
87  "being fed into this binary. So we are actually combining last n models.\n"
88  "Inputs and outputs are 'raw' nnets.\n"
89  "\n"
90  "Usage: nnet3-combine [options] <nnet-in1> <nnet-in2> ... <nnet-inN> <valid-examples-in> <nnet-out>\n"
91  "\n"
92  "e.g.:\n"
93  " nnet3-combine 1.1.raw 1.2.raw 1.3.raw ark:valid.egs 2.raw\n";
94 
95  bool binary_write = true;
96  int32 max_objective_evaluations = 30;
97  bool batchnorm_test_mode = false,
98  dropout_test_mode = true;
99  std::string use_gpu = "yes";
100 
101  ParseOptions po(usage);
102  po.Register("binary", &binary_write, "Write output in binary mode");
103  po.Register("max-objective-evaluations", &max_objective_evaluations, "The "
104  "maximum number of objective evaluations in order to figure "
105  "out the best number of models to combine. It helps to speedup "
106  "if the number of models provided to this binary is quite "
107  "large (e.g. several hundred).");
108  po.Register("batchnorm-test-mode", &batchnorm_test_mode,
109  "If true, set test-mode to true on any BatchNormComponents "
110  "while evaluating objectives.");
111  po.Register("dropout-test-mode", &dropout_test_mode,
112  "If true, set test-mode to true on any DropoutComponents and "
113  "DropoutMaskComponents while evaluating objectives.");
114  po.Register("use-gpu", &use_gpu,
115  "yes|no|optional|wait, only has effect if compiled with CUDA");
116 
117  po.Read(argc, argv);
118 
119  if (po.NumArgs() < 3) {
120  po.PrintUsage();
121  exit(1);
122  }
123 
124 #if HAVE_CUDA==1
125  CuDevice::Instantiate().SelectGpuId(use_gpu);
126 #endif
127 
128  std::string
129  nnet_rxfilename = po.GetArg(1),
130  valid_examples_rspecifier = po.GetArg(po.NumArgs() - 1),
131  nnet_wxfilename = po.GetArg(po.NumArgs());
132 
133  Nnet nnet;
134  ReadKaldiObject(nnet_rxfilename, &nnet);
135  Nnet moving_average_nnet(nnet), best_nnet(nnet);
136  NnetComputeProbOptions compute_prob_opts;
137  NnetComputeProb prob_computer(compute_prob_opts, moving_average_nnet);
138 
139  std::vector<NnetExample> egs;
140  egs.reserve(10000); // reserve a lot of space to minimize the chance of
141  // reallocation.
142 
143  { // This block adds training examples to "egs".
144  SequentialNnetExampleReader example_reader(
145  valid_examples_rspecifier);
146  for (; !example_reader.Done(); example_reader.Next())
147  egs.push_back(example_reader.Value());
148  KALDI_LOG << "Read " << egs.size() << " examples.";
149  KALDI_ASSERT(!egs.empty());
150  }
151 
152  // first evaluates the objective using the last model.
153  int32 best_num_to_combine = 1;
154  double
155  init_objf = ComputeObjf(batchnorm_test_mode, dropout_test_mode,
156  egs, moving_average_nnet, &prob_computer),
157  best_objf = init_objf;
158  KALDI_LOG << "objective function using the last model is " << init_objf;
159 
160  int32 num_nnets = po.NumArgs() - 2;
161  // then each time before we re-evaluate the objective function, we will add
162  // num_to_add models to the moving average.
163  int32 num_to_add = (num_nnets + max_objective_evaluations - 1) /
164  max_objective_evaluations;
165  for (int32 n = 1; n < num_nnets; n++) {
166  ReadKaldiObject(po.GetArg(1 + n), &nnet);
167  // updates the moving average
168  UpdateNnetMovingAverage(n + 1, nnet, &moving_average_nnet);
169  // evaluates the objective everytime after adding num_to_add model or
170  // all the models to the moving average.
171  if ((n - 1) % num_to_add == num_to_add - 1 || n == num_nnets - 1) {
172  double objf = ComputeObjf(batchnorm_test_mode, dropout_test_mode,
173  egs, moving_average_nnet, &prob_computer);
174  KALDI_LOG << "Combining last " << n + 1
175  << " models, objective function is " << objf;
176  if (objf > best_objf) {
177  best_objf = objf;
178  best_nnet = moving_average_nnet;
179  best_num_to_combine = n + 1;
180  }
181  }
182  }
183  KALDI_LOG << "Combining " << best_num_to_combine
184  << " nnets, objective function changed from " << init_objf
185  << " to " << best_objf;
186 
187  if (HasBatchnorm(nnet))
188  RecomputeStats(egs, &best_nnet);
189 
190 #if HAVE_CUDA==1
191  CuDevice::Instantiate().PrintProfile();
192 #endif
193 
194  WriteKaldiObject(best_nnet, nnet_wxfilename, binary_write);
195  KALDI_LOG << "Finished combining neural nets, wrote model to "
196  << nnet_wxfilename;
197  } catch(const std::exception &e) {
198  std::cerr << e.what() << '\n';
199  return -1;
200  }
201 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void ScaleNnet(BaseFloat scale, Nnet *nnet)
Scales the nnet parameters and stats by this scale.
Definition: nnet-utils.cc:312
void UpdateNnetMovingAverage(int32 num_models, const Nnet &nnet, Nnet *moving_average_nnet)
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void Compute(const NnetExample &eg)
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
Definition: nnet-utils.cc:564
kaldi::int32 int32
This class is for computing cross-entropy and accuracy values in a neural network, for diagnostics.
double ComputeObjf(bool batchnorm_test_mode, bool dropout_test_mode, const std::vector< NnetExample > &egs, const Nnet &nnet, NnetComputeProb *prob_computer)
void Register(const std::string &name, bool *ptr, const std::string &doc)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
This file contains some miscellaneous functions dealing with class Nnet.
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
Definition: nnet-utils.cc:573
int32 NumParameters(const Nnet &src)
Returns the total of the number of parameters in the updatable components of the nnet.
Definition: nnet-utils.cc:359
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void RecomputeStats(const std::vector< NnetChainExample > &egs, const chain::ChainTrainingOptions &chain_config_in, const fst::StdVectorFst &den_fst, Nnet *nnet)
This function zeros the stored component-level stats in the nnet using ZeroComponentStats(), then recomputes them with the supplied egs.
struct rnnlm::@11::@12 n
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
bool HasBatchnorm(const Nnet &nnet)
Returns true if nnet has at least one component of type BatchNormComponent.
Definition: nnet-utils.cc:527
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
int NumArgs() const
Number of positional parameters (c.f. argc-1).
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
double GetTotalObjective(double *tot_weight) const
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Definition: kaldi-io.h:257
int main(int argc, char *argv[])
#define KALDI_LOG
Definition: kaldi-error.h:153
void AddNnet(const Nnet &src, BaseFloat alpha, Nnet *dest)
Does *dest += alpha * src (affects nnet parameters and stored stats).
Definition: nnet-utils.cc:349