25 int main(
int argc,
char *argv[]) {
27 using namespace kaldi;
30 typedef kaldi::int64 int64;
33 "Train nnet3 neural network parameters with backprop and stochastic\n" 34 "gradient descent. Minibatches are to be created by nnet3-merge-egs in\n" 35 "the input pipeline. This training program is single-threaded (best to\n" 36 "use it with a GPU); see nnet3-train-parallel for multi-threaded training\n" 37 "that is better suited to CPUs.\n" 39 "Usage: nnet3-train [options] <raw-model-in> <training-examples-in> <raw-model-out>\n" 42 "nnet3-train 1.raw 'ark:nnet3-merge-egs 1.egs ark:-|' 2.raw\n";
45 bool binary_write =
true;
46 std::string use_gpu =
"yes";
50 po.
Register(
"srand", &srand_seed,
"Seed for random number generator ");
51 po.
Register(
"binary", &binary_write,
"Write output in binary mode");
53 "yes|no|optional|wait, only has effect if compiled with CUDA");
68 CuDevice::Instantiate().SelectGpuId(use_gpu);
71 std::string nnet_rxfilename = po.
GetArg(1),
72 examples_rspecifier = po.
GetArg(2),
73 nnet_wxfilename = po.
GetArg(3);
82 for (; !example_reader.
Done(); example_reader.
Next())
88 CuDevice::Instantiate().PrintProfile();
91 KALDI_LOG <<
"Wrote model to " << nnet_wxfilename;
93 }
catch(
const std::exception &e) {
94 std::cerr << e.what() <<
'\n';
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void Register(OptionsItf *opts)
void Train(const NnetExample &eg)
void Register(const std::string &name, bool *ptr, const std::string &doc)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
bool PrintTotalStats() const
int NumArgs() const
Number of positional parameters (cf. argc-1).
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
This class is for single-threaded training of neural nets using standard objective functions such as ...
void RegisterCuAllocatorOptions(OptionsItf *po)
int main(int argc, char *argv[])