// Entry point of the nnet-train-ensemble command-line tool.
//
// Per the usage text below, the positional arguments are n input models, one
// training-examples specifier, and n output models.  The visible code reads
// the models into am_nnets, iterates over the training examples while
// counting them, and writes every model back out.
//
// Returns 1 if no training examples were processed and 0 otherwise.
//
// NOTE(review): this listing is an excerpt with gaps.  Several names used
// below (po, srand_seed, trans_model, binary_read, example_reader, ko) are
// declared on lines that are not shown, and the body of the example loop is
// not visible either.
28 int main(
int argc,
char *argv[]) {
30 using namespace kaldi;
33 typedef kaldi::int64 int64;
// Usage/help text (pieces of one concatenated string literal).
36 "Train an ensemble of neural networks with backprop and stochastic\n" 37 "gradient descent using minibatches. Modified version of nnet-train-simple.\n" 38 "Implements parallel gradient descent with a term that encourages the nnets to\n" 39 "produce similar outputs.\n" 41 "Usage: nnet-train-ensemble [options] <model-in-1> <model-in-2> ... <model-in-n>" 42 " <training-examples-in> <model-out-1> <model-out-2> ... <model-out-n> \n" 45 " nnet-train-ensemble 1.1.nnet 2.1.nnet ark:egs.ark 2.1.nnet 2.2.nnet \n";
// Default values for the options registered with `po` below.
47 bool binary_write =
true;
48 bool zero_stats =
true;
50 std::string use_gpu =
"yes";
// Bind each command-line option name to the variable that receives its value.
54 po.
Register(
"binary", &binary_write,
"Write output in binary mode");
55 po.
Register(
"zero-stats", &zero_stats,
"If true, zero occupation " 56 "counts stored with the neural net (only affects mixing up).");
57 po.
Register(
"srand", &srand_seed,
"Seed for random number generator " 58 "(relevant if you have layers of type AffineComponentPreconditioned " 59 "with l2-penalty != 0.0");
61 "yes|no|optional|wait, only has effect if compiled with CUDA");
// Select the GPU according to the use_gpu option value.
74 CuDevice::Instantiate().SelectGpuId(use_gpu);
// With n inputs, one examples argument and n outputs there are 2n + 1
// positional arguments, hence n = (NumArgs - 1) / 2.
77 int32 num_nnets = (po.
NumArgs() - 1) / 2;
// NOTE(review): nnet_rxfilename and nnet1_rxfilename (below) both hold
// GetArg(1); only nnet1_rxfilename is used in the visible lines.
78 std::string nnet_rxfilename = po.
GetArg(1);
// The examples argument is the one immediately after the n input models.
79 std::string examples_rspecifier = po.
GetArg(num_nnets + 1);
81 std::string nnet1_rxfilename = po.
GetArg(1);
84 std::vector<AmNnet> am_nnets(num_nnets);
// Open the first model file and read it into am_nnets[0].
87 Input ki(nnet1_rxfilename, &binary_read);
90 am_nnets[0].Read(ki.
Stream(), binary_read);
// nnets[i] points at the Nnet held inside am_nnets[i].
93 std::vector<Nnet*> nnets(num_nnets);
94 nnets[0] = &(am_nnets[0].GetNnet());
// Read the remaining models (indices 1 .. num_nnets - 1); index 0 was read
// above.  For each one, trans_model is read from the stream before the
// AmNnet itself.
96 for (int32
n = 1;
n < num_nnets;
n++) {
100 trans_model.
Read(ki.Stream(), binary_read);
101 am_nnets[
n].Read(ki.Stream(), binary_read);
102 nnets[
n] = &am_nnets[
n].GetNnet();
// Counts the examples consumed by the reader loop below.
106 int64 num_examples = 0;
// NOTE(review): this loop starts at n = 1, so ZeroStats() is not called on
// nnets[0] in the visible code -- confirm whether skipping the first net is
// intended.
110 for (int32
n = 1;
n < num_nnets;
n++)
111 nnets[
n]->ZeroStats();
// Step through every training example, incrementing num_examples once per
// example (the loop body is not visible in this excerpt).
120 for (; !example_reader.
Done(); example_reader.
Next(), num_examples++)
// Write out every model, starting from index 0.
125 for (int32
n = 0;
n < num_nnets;
n++) {
128 am_nnets[
n].Write(ko.
Stream(), binary_write);
133 CuDevice::Instantiate().PrintProfile();
136 KALDI_LOG <<
"Finished training, processed " << num_examples
137 <<
" training examples.";
// Exit status 1 signals that no examples were processed.
138 return (num_examples == 0 ? 1 : 0);
139 }
// Print the message of any std::exception to stderr.
catch(
const std::exception &e) {
140 std::cerr << e.what() <<
'\n';
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for mispronunciation detection tasks.
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
int main(int argc, char *argv[])
void Register(const std::string &name, bool *ptr, const std::string &doc)
The class ParseOptions is for parsing command-line options; see "Parsing command-line options" for more documentation.
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see "The Table concept".
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
void TrainOnExample(const NnetExample &value)
TrainOnExample will take the example and add it to a buffer; if we've reached the minibatch size it will do the training.
void Register(OptionsItf *opts)
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void Write(std::ostream &os, bool binary) const