// Entry point of the nnet-train-parallel tool (name taken from the usage text
// below).
// NOTE(review): this listing is a fragment. Each statement carries a leading
// number that looks like the original file's line number, and several
// statements are not visible here (e.g. the declarations of `po` and
// `binary_read`, the opening `try`, and the call that does the training).
27 int main(
int argc,
char *argv[]) {
29 using namespace kaldi;
32 typedef kaldi::int64 int64;
// Usage/help text: describes the tool, its three positional arguments and an
// example invocation.
35 "Train the neural network parameters with backprop and stochastic\n" 36 "gradient descent using minibatches. As nnet-train-simple, but\n" 37 "uses multiple threads in a Hogwild type of update (for CPU, not GPU).\n" 39 "Usage: nnet-train-parallel [options] <model-in> <training-examples-in> <model-out>\n" 42 "nnet-train-parallel --num-threads=8 1.nnet ark:1.1.egs 2.nnet\n";
// Default values for the options registered below.
44 bool binary_write =
true;
45 bool zero_stats =
true;
46 int32 minibatch_size = 1024;
// Bind each option name to the variable it sets, with its help string.
50 po.
Register(
"binary", &binary_write,
"Write output in binary mode");
51 po.
Register(
"zero-stats", &zero_stats,
"If true, zero stats " 52 "stored with the neural net (only affects mixing up).");
// The next two lines are the help-string tails of two further registrations
// whose opening lines are not shown in this fragment.
54 "Seed for random number generator (e.g., for dropout)");
56 "in the parallel update. [Note: if you use a parallel " 57 "implementation of BLAS, the actual number of threads may be larger.]");
58 po.
Register(
"minibatch-size", &minibatch_size,
"Number of examples to use for " 59 "each minibatch during training.");
// Positional arguments: 1 = model to read, 2 = training-examples rspecifier,
// 3 = model to write.
69 std::string nnet_rxfilename = po.
GetArg(1),
70 examples_rspecifier = po.
GetArg(2),
71 nnet_wxfilename = po.
GetArg(3);
// Open the input model; `binary_read` is declared outside this fragment and
// is passed by address (presumably filled in with the file's mode -- confirm).
77 Input ki(nnet_rxfilename, &binary_read);
// Count of training examples processed; the log message below calls it
// "weighted". The code that updates it is not visible here.
86 double num_examples = 0;
// Open the output model in the mode selected by the "binary" option.
97 Output ko(nnet_wxfilename, binary_write);
// Log a summary, then return exit status 1 if no examples were processed and
// 0 otherwise.
102 KALDI_LOG <<
"Finished training, processed " << num_examples
103 <<
" training examples (weighted). Wrote model to " 105 return (num_examples == 0 ? 1 : 0);
106 }
// Any std::exception is reported on stderr via what(); the rest of the
// handler is not visible in this fragment.
catch(
const std::exception &e) {
107 std::cerr << e.what() <<
'\n';
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void Read(std::istream &is, bool binary)
int main(int argc, char *argv[])
void Register(const std::string &name, bool *ptr, const std::string &doc)
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more documentation.
void Read(std::istream &is, bool binary)
void Write(std::ostream &os, bool binary) const
A templated class for reading objects sequentially from an archive or script file; see The Table concept.
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void Write(std::ostream &os, bool binary) const
#define KALDI_ASSERT(cond)
double DoBackpropParallel(const Nnet &nnet, int32 minibatch_size, SequentialNnetExampleReader *examples_reader, double *tot_weight, Nnet *nnet_to_update)
This function is similar to "DoBackprop" in nnet-update.h. This function computes the objective functi...
const Nnet & GetNnet() const