nnet3-train.cc
Go to the documentation of this file.
1 // nnet3bin/nnet3-train.cc
2 
3 // Copyright 2015 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include "base/kaldi-common.h"
21 #include "util/common-utils.h"
22 #include "nnet3/nnet-training.h"
23 #include "cudamatrix/cu-allocator.h"
24 
25 int main(int argc, char *argv[]) {
26  try {
27  using namespace kaldi;
28  using namespace kaldi::nnet3;
29  typedef kaldi::int32 int32;
30  typedef kaldi::int64 int64;
31 
32  const char *usage =
33  "Train nnet3 neural network parameters with backprop and stochastic\n"
34  "gradient descent. Minibatches are to be created by nnet3-merge-egs in\n"
35  "the input pipeline. This training program is single-threaded (best to\n"
36  "use it with a GPU); see nnet3-train-parallel for multi-threaded training\n"
37  "that is better suited to CPUs.\n"
38  "\n"
39  "Usage: nnet3-train [options] <raw-model-in> <training-examples-in> <raw-model-out>\n"
40  "\n"
41  "e.g.:\n"
42  "nnet3-train 1.raw 'ark:nnet3-merge-egs 1.egs ark:-|' 2.raw\n";
43 
44  int32 srand_seed = 0;
45  bool binary_write = true;
46  std::string use_gpu = "yes";
47  NnetTrainerOptions train_config;
48 
49  ParseOptions po(usage);
50  po.Register("srand", &srand_seed, "Seed for random number generator ");
51  po.Register("binary", &binary_write, "Write output in binary mode");
52  po.Register("use-gpu", &use_gpu,
53  "yes|no|optional|wait, only has effect if compiled with CUDA");
54 
55  train_config.Register(&po);
57 
58  po.Read(argc, argv);
59 
60  srand(srand_seed);
61 
62  if (po.NumArgs() != 3) {
63  po.PrintUsage();
64  exit(1);
65  }
66 
67 #if HAVE_CUDA==1
68  CuDevice::Instantiate().SelectGpuId(use_gpu);
69 #endif
70 
71  std::string nnet_rxfilename = po.GetArg(1),
72  examples_rspecifier = po.GetArg(2),
73  nnet_wxfilename = po.GetArg(3);
74 
75  Nnet nnet;
76  ReadKaldiObject(nnet_rxfilename, &nnet);
77 
78  NnetTrainer trainer(train_config, &nnet);
79 
80  SequentialNnetExampleReader example_reader(examples_rspecifier);
81 
82  for (; !example_reader.Done(); example_reader.Next())
83  trainer.Train(example_reader.Value());
84 
85  bool ok = trainer.PrintTotalStats();
86 
87 #if HAVE_CUDA==1
88  CuDevice::Instantiate().PrintProfile();
89 #endif
90  WriteKaldiObject(nnet, nnet_wxfilename, binary_write);
91  KALDI_LOG << "Wrote model to " << nnet_wxfilename;
92  return (ok ? 0 : 1);
93  } catch(const std::exception &e) {
94  std::cerr << e.what() << '\n';
95  return -1;
96  }
97 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void Register(OptionsItf *opts)
Definition: nnet-training.h:63
kaldi::int32 int32
void Train(const NnetExample &eg)
void Register(const std::string &name, bool *ptr, const std::string &doc)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
int NumArgs() const
Number of positional parameters (cf. argc-1).
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Definition: kaldi-io.h:257
This class is for single-threaded training of neural nets using standard objective functions such as ...
void RegisterCuAllocatorOptions(OptionsItf *po)
Definition: cu-allocator.h:87
#define KALDI_LOG
Definition: kaldi-error.h:153
int main(int argc, char *argv[])
Definition: nnet3-train.cc:25