// Entry point of the nnet-train-simple tool. Per its usage string, it trains
// neural-network parameters with backprop and minibatch SGD on examples
// produced by nnet-get-egs:
//   nnet-train-simple [options] <model-in> <training-examples-in> <model-out>
// Exit status (visible below): 1 if zero training examples were processed,
// 0 otherwise.
// NOTE(review): this is a fragment extracted from generated documentation; the
// leading numbers are original source line numbers, and several source lines
// (e.g. the `try {`, the ParseOptions construction, the use-gpu Register call,
// model reading and the training call) are absent from this view.
27 int main(
int argc,
char *argv[]) {
29 using namespace kaldi;
// Local alias for Kaldi's 64-bit integer type.
32 typedef kaldi::int64 int64;
// Usage/help text. Presumably assigned to a `usage` string that is passed to
// ParseOptions (the declaration is not visible here) -- confirm against source.
35 "Train the neural network parameters with backprop and stochastic\n" 36 "gradient descent using minibatches. Training examples would be\n" 37 "produced by nnet-get-egs.\n" 39 "Usage: nnet-train-simple [options] <model-in> <training-examples-in> <model-out>\n" 42 "nnet-train-simple 1.nnet ark:1.egs 2.nnet\n";
// Option defaults: write the output model in binary, zero the stored
// occupation counts, and request a GPU ("yes").
44 bool binary_write =
true;
45 bool zero_stats =
true;
47 std::string use_gpu =
"yes";
// --binary: controls whether the output model is written in binary mode.
51 po.
Register(
"binary", &binary_write,
"Write output in binary mode");
// --zero-stats: per its help text, only affects mixing up.
52 po.
Register(
"zero-stats", &zero_stats,
"If true, zero occupation " 53 "counts stored with the neural net (only affects mixing up).");
// --srand: RNG seed; `srand_seed` is declared on a line not shown here.
54 po.
Register(
"srand", &srand_seed,
"Seed for random number generator " 55 "(relevant if you have layers of type AffineComponentPreconditioned " 56 "with l2-penalty != 0.0");
// Tail of the help text for the GPU-selection option; the Register call it
// belongs to is not visible in this fragment.
58 "yes|no|optional|wait, only has effect if compiled with CUDA");
// Choose the CUDA device according to the `use_gpu` setting.
71 CuDevice::Instantiate().SelectGpuId(use_gpu);
// Positional arguments (1-based): input model, training-examples rspecifier,
// output model -- matching the usage string above.
74 std::string nnet_rxfilename = po.
GetArg(1),
75 examples_rspecifier = po.
GetArg(2),
76 nnet_wxfilename = po.
GetArg(3);
// Open the input model; `binary_read` is passed by pointer, presumably so
// Input can report whether the stream is binary -- confirm against Kaldi I/O docs.
85 Input ki(nnet_rxfilename, &binary_read);
// Open the output model for writing, in binary or text per --binary.
98 Output ko(nnet_wxfilename, binary_write);
// Print CUDA device profiling information.
104 CuDevice::Instantiate().PrintProfile();
// Log the number of processed examples, then return 1 if no examples were
// processed and 0 otherwise.
107 KALDI_LOG <<
"Finished training, processed " << num_examples
108 <<
" training examples. Wrote model to " 110 return (num_examples == 0 ? 1 : 0);
111 }
// Any std::exception escaping the body is reported on stderr; the rest of the
// handler (its return value and closing braces) is not visible in this fragment.
catch(
const std::exception &e) {
112 std::cerr << e.what() <<
'\n';
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
void Register(OptionsItf *opts)
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void Read(std::istream &is, bool binary)
void Register(const std::string &name, bool *ptr, const std::string &doc)
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more documentation.
int64 TrainNnetSimple(const NnetSimpleTrainerConfig &config, Nnet *nnet, SequentialNnetExampleReader *reader, double *tot_weight_ptr, double *tot_logprob_ptr)
Train on all the examples it can read from the reader.
int main(int argc, char *argv[])
void Read(std::istream &is, bool binary)
void Write(std::ostream &os, bool binary) const
A templated class for reading objects sequentially from an archive or script file; see The Table concept.
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
int NumArgs() const
Number of positional parameters (cf. argc-1).
void Write(std::ostream &os, bool binary) const
const Nnet & GetNnet() const