nnet-train-ensemble.cc
Go to the documentation of this file.
1 // nnet2bin/nnet-train-ensemble.cc
2 
3 // Copyright 2012 Johns Hopkins University (author: Daniel Povey)
4 // 2014 Xiaohui Zhang
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include "base/kaldi-common.h"
22 #include "util/common-utils.h"
23 #include "hmm/transition-model.h"
25 #include "nnet2/am-nnet.h"
26 
27 
28 int main(int argc, char *argv[]) {
29  try {
30  using namespace kaldi;
31  using namespace kaldi::nnet2;
32  typedef kaldi::int32 int32;
33  typedef kaldi::int64 int64;
34 
35  const char *usage =
36  "Train an ensemble of neural networks with backprop and stochastic\n"
37  "gradient descent using minibatches. Modified version of nnet-train-simple.\n"
38  "Implements parallel gradient descent with a term that encourages the nnets to\n"
39  "produce similar outputs.\n"
40  "\n"
41  "Usage: nnet-train-ensemble [options] <model-in-1> <model-in-2> ... <model-in-n>"
42  " <training-examples-in> <model-out-1> <model-out-2> ... <model-out-n> \n"
43  "\n"
44  "e.g.:\n"
45  " nnet-train-ensemble 1.1.nnet 2.1.nnet ark:egs.ark 2.1.nnet 2.2.nnet \n";
46 
47  bool binary_write = true;
48  bool zero_stats = true;
49  int32 srand_seed = 0;
50  std::string use_gpu = "yes";
51  NnetEnsembleTrainerConfig train_config;
52 
53  ParseOptions po(usage);
54  po.Register("binary", &binary_write, "Write output in binary mode");
55  po.Register("zero-stats", &zero_stats, "If true, zero occupation "
56  "counts stored with the neural net (only affects mixing up).");
57  po.Register("srand", &srand_seed, "Seed for random number generator "
58  "(relevant if you have layers of type AffineComponentPreconditioned "
59  "with l2-penalty != 0.0");
60  po.Register("use-gpu", &use_gpu,
61  "yes|no|optional|wait, only has effect if compiled with CUDA");
62 
63  train_config.Register(&po);
64 
65  po.Read(argc, argv);
66 
67  if (po.NumArgs() <= 3) {
68  po.PrintUsage();
69  exit(1);
70  }
71  srand(srand_seed);
72 
73 #if HAVE_CUDA==1
74  CuDevice::Instantiate().SelectGpuId(use_gpu);
75 #endif
76 
77  int32 num_nnets = (po.NumArgs() - 1) / 2;
78  std::string nnet_rxfilename = po.GetArg(1);
79  std::string examples_rspecifier = po.GetArg(num_nnets + 1);
80 
81  std::string nnet1_rxfilename = po.GetArg(1);
82 
83  TransitionModel trans_model;
84  std::vector<AmNnet> am_nnets(num_nnets);
85  {
86  bool binary_read;
87  Input ki(nnet1_rxfilename, &binary_read);
88  trans_model.Read(ki.Stream(), binary_read);
89  KALDI_LOG << nnet1_rxfilename;
90  am_nnets[0].Read(ki.Stream(), binary_read);
91  }
92 
93  std::vector<Nnet*> nnets(num_nnets);
94  nnets[0] = &(am_nnets[0].GetNnet());
95 
96  for (int32 n = 1; n < num_nnets; n++) {
97  TransitionModel trans_model;
98  bool binary_read;
99  Input ki(po.GetArg(1 + n), &binary_read);
100  trans_model.Read(ki.Stream(), binary_read);
101  am_nnets[n].Read(ki.Stream(), binary_read);
102  nnets[n] = &am_nnets[n].GetNnet();
103  }
104 
105 
106  int64 num_examples = 0;
107 
108  {
109  if (zero_stats) {
110  for (int32 n = 1; n < num_nnets; n++)
111  nnets[n]->ZeroStats();
112  }
113  { // want to make sure this object deinitializes before
114  // we write the model, as it does something in the destructor.
115  NnetEnsembleTrainer trainer(train_config,
116  nnets);
117 
118  SequentialNnetExampleReader example_reader(examples_rspecifier);
119 
120  for (; !example_reader.Done(); example_reader.Next(), num_examples++)
121  trainer.TrainOnExample(example_reader.Value()); // It all happens here!
122  }
123 
124  {
125  for (int32 n = 0; n < num_nnets; n++) {
126  Output ko(po.GetArg(po.NumArgs() - num_nnets + n + 1), binary_write);
127  trans_model.Write(ko.Stream(), binary_write);
128  am_nnets[n].Write(ko.Stream(), binary_write);
129  }
130  }
131  }
132 #if HAVE_CUDA==1
133  CuDevice::Instantiate().PrintProfile();
134 #endif
135 
136  KALDI_LOG << "Finished training, processed " << num_examples
137  << " training examples.";
138  return (num_examples == 0 ? 1 : 0);
139  } catch(const std::exception &e) {
140  std::cerr << e.what() << '\n';
141  return -1;
142  }
143 }
144 
145 
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
kaldi::int32 int32
int main(int argc, char *argv[])
void Register(const std::string &name, bool *ptr, const std::string &doc)
std::istream & Stream()
Definition: kaldi-io.cc:826
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
std::ostream & Stream()
Definition: kaldi-io.cc:701
void Read(std::istream &is, bool binary)
struct rnnlm::@11::@12 n
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
void TrainOnExample(const NnetExample &value)
TrainOnExample will take the example and add it to a buffer; if we've reached the minibatch size it w...
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void Write(std::ostream &os, bool binary) const
#define KALDI_LOG
Definition: kaldi-error.h:153