nnet3-train.cc File Reference
Include dependency graph for nnet3-train.cc: (graph image not reproduced in this text version)

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main(int argc, char *argv[])

Definition at line 25 of file nnet3-train.cc.

References SequentialTableReader< Holder >::Done(), ParseOptions::GetArg(), KALDI_LOG, SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), NnetTrainer::PrintTotalStats(), ParseOptions::PrintUsage(), ParseOptions::Read(), kaldi::ReadKaldiObject(), NnetTrainerOptions::Register(), ParseOptions::Register(), kaldi::RegisterCuAllocatorOptions(), NnetTrainer::Train(), SequentialTableReader< Holder >::Value(), and kaldi::WriteKaldiObject().

25  {
26  try {
27  using namespace kaldi;
28  using namespace kaldi::nnet3;
29  typedef kaldi::int32 int32;
30  typedef kaldi::int64 int64;
31 
32  const char *usage =
33  "Train nnet3 neural network parameters with backprop and stochastic\n"
34  "gradient descent. Minibatches are to be created by nnet3-merge-egs in\n"
35  "the input pipeline. This training program is single-threaded (best to\n"
36  "use it with a GPU); see nnet3-train-parallel for multi-threaded training\n"
37  "that is better suited to CPUs.\n"
38  "\n"
39  "Usage: nnet3-train [options] <raw-model-in> <training-examples-in> <raw-model-out>\n"
40  "\n"
41  "e.g.:\n"
42  "nnet3-train 1.raw 'ark:nnet3-merge-egs 1.egs ark:-|' 2.raw\n";
43 
44  int32 srand_seed = 0;
45  bool binary_write = true;
46  std::string use_gpu = "yes";
47  NnetTrainerOptions train_config;
48 
49  ParseOptions po(usage);
50  po.Register("srand", &srand_seed, "Seed for random number generator ");
51  po.Register("binary", &binary_write, "Write output in binary mode");
52  po.Register("use-gpu", &use_gpu,
53  "yes|no|optional|wait, only has effect if compiled with CUDA");
54 
55  train_config.Register(&po);
57 
58  po.Read(argc, argv);
59 
60  srand(srand_seed);
61 
62  if (po.NumArgs() != 3) {
63  po.PrintUsage();
64  exit(1);
65  }
66 
67 #if HAVE_CUDA==1
68  CuDevice::Instantiate().SelectGpuId(use_gpu);
69 #endif
70 
71  std::string nnet_rxfilename = po.GetArg(1),
72  examples_rspecifier = po.GetArg(2),
73  nnet_wxfilename = po.GetArg(3);
74 
75  Nnet nnet;
76  ReadKaldiObject(nnet_rxfilename, &nnet);
77 
78  NnetTrainer trainer(train_config, &nnet);
79 
80  SequentialNnetExampleReader example_reader(examples_rspecifier);
81 
82  for (; !example_reader.Done(); example_reader.Next())
83  trainer.Train(example_reader.Value());
84 
85  bool ok = trainer.PrintTotalStats();
86 
87 #if HAVE_CUDA==1
88  CuDevice::Instantiate().PrintProfile();
89 #endif
90  WriteKaldiObject(nnet, nnet_wxfilename, binary_write);
91  KALDI_LOG << "Wrote model to " << nnet_wxfilename;
92  return (ok ? 0 : 1);
93  } catch(const std::exception &e) {
94  std::cerr << e.what() << '\n';
95  return -1;
96  }
97 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
Definition: chain.dox:20
void Register(OptionsItf *opts)
Definition: nnet-training.h:63
kaldi::int32 int32
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Definition: kaldi-io.h:257
This class is for single-threaded training of neural nets using standard objective functions such as ...
void RegisterCuAllocatorOptions(OptionsItf *po)
Definition: cu-allocator.h:87
#define KALDI_LOG
Definition: kaldi-error.h:153