nnet-train-ensemble.cc File Reference
Include dependency graph for nnet-train-ensemble.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main(int argc, char *argv[])

Definition at line 28 of file nnet-train-ensemble.cc.

References SequentialTableReader< Holder >::Done(), ParseOptions::GetArg(), KALDI_LOG, rnnlm::n, SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), ParseOptions::PrintUsage(), ParseOptions::Read(), TransitionModel::Read(), NnetEnsembleTrainerConfig::Register(), ParseOptions::Register(), Output::Stream(), Input::Stream(), NnetEnsembleTrainer::TrainOnExample(), SequentialTableReader< Holder >::Value(), and TransitionModel::Write().

28  {
29  try {
30  using namespace kaldi;
31  using namespace kaldi::nnet2;
32  typedef kaldi::int32 int32;
33  typedef kaldi::int64 int64;
34 
35  const char *usage =
36  "Train an ensemble of neural networks with backprop and stochastic\n"
37  "gradient descent using minibatches. Modified version of nnet-train-simple.\n"
38  "Implements parallel gradient descent with a term that encourages the nnets to\n"
39  "produce similar outputs.\n"
40  "\n"
41  "Usage: nnet-train-ensemble [options] <model-in-1> <model-in-2> ... <model-in-n>"
42  " <training-examples-in> <model-out-1> <model-out-2> ... <model-out-n> \n"
43  "\n"
44  "e.g.:\n"
45  " nnet-train-ensemble 1.1.nnet 2.1.nnet ark:egs.ark 2.1.nnet 2.2.nnet \n";
46 
47  bool binary_write = true;
48  bool zero_stats = true;
49  int32 srand_seed = 0;
50  std::string use_gpu = "yes";
51  NnetEnsembleTrainerConfig train_config;
52 
53  ParseOptions po(usage);
54  po.Register("binary", &binary_write, "Write output in binary mode");
55  po.Register("zero-stats", &zero_stats, "If true, zero occupation "
56  "counts stored with the neural net (only affects mixing up).");
57  po.Register("srand", &srand_seed, "Seed for random number generator "
58  "(relevant if you have layers of type AffineComponentPreconditioned "
59  "with l2-penalty != 0.0");
60  po.Register("use-gpu", &use_gpu,
61  "yes|no|optional|wait, only has effect if compiled with CUDA");
62 
63  train_config.Register(&po);
64 
65  po.Read(argc, argv);
66 
67  if (po.NumArgs() <= 3) {
68  po.PrintUsage();
69  exit(1);
70  }
71  srand(srand_seed);
72 
73 #if HAVE_CUDA==1
74  CuDevice::Instantiate().SelectGpuId(use_gpu);
75 #endif
76 
77  int32 num_nnets = (po.NumArgs() - 1) / 2;
78  std::string nnet_rxfilename = po.GetArg(1);
79  std::string examples_rspecifier = po.GetArg(num_nnets + 1);
80 
81  std::string nnet1_rxfilename = po.GetArg(1);
82 
83  TransitionModel trans_model;
84  std::vector<AmNnet> am_nnets(num_nnets);
85  {
86  bool binary_read;
87  Input ki(nnet1_rxfilename, &binary_read);
88  trans_model.Read(ki.Stream(), binary_read);
89  KALDI_LOG << nnet1_rxfilename;
90  am_nnets[0].Read(ki.Stream(), binary_read);
91  }
92 
93  std::vector<Nnet*> nnets(num_nnets);
94  nnets[0] = &(am_nnets[0].GetNnet());
95 
96  for (int32 n = 1; n < num_nnets; n++) {
97  TransitionModel trans_model;
98  bool binary_read;
99  Input ki(po.GetArg(1 + n), &binary_read);
100  trans_model.Read(ki.Stream(), binary_read);
101  am_nnets[n].Read(ki.Stream(), binary_read);
102  nnets[n] = &am_nnets[n].GetNnet();
103  }
104 
105 
106  int64 num_examples = 0;
107 
108  {
109  if (zero_stats) {
110  for (int32 n = 1; n < num_nnets; n++)
111  nnets[n]->ZeroStats();
112  }
113  { // want to make sure this object deinitializes before
114  // we write the model, as it does something in the destructor.
115  NnetEnsembleTrainer trainer(train_config,
116  nnets);
117 
118  SequentialNnetExampleReader example_reader(examples_rspecifier);
119 
120  for (; !example_reader.Done(); example_reader.Next(), num_examples++)
121  trainer.TrainOnExample(example_reader.Value()); // It all happens here!
122  }
123 
124  {
125  for (int32 n = 0; n < num_nnets; n++) {
126  Output ko(po.GetArg(po.NumArgs() - num_nnets + n + 1), binary_write);
127  trans_model.Write(ko.Stream(), binary_write);
128  am_nnets[n].Write(ko.Stream(), binary_write);
129  }
130  }
131  }
132 #if HAVE_CUDA==1
133  CuDevice::Instantiate().PrintProfile();
134 #endif
135 
136  KALDI_LOG << "Finished training, processed " << num_examples
137  << " training examples.";
138  return (num_examples == 0 ? 1 : 0);
139  } catch(const std::exception &e) {
140  std::cerr << e.what() << '\n';
141  return -1;
142  }
143 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
kaldi::int32 int32
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Read(std::istream &is, bool binary)
struct rnnlm::@11::@12 n
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
void Write(std::ostream &os, bool binary) const
#define KALDI_LOG
Definition: kaldi-error.h:153