26 int main(
int argc,
char *argv[]) {
28 using namespace kaldi;
31 typedef kaldi::int64 int64;
34 "Train nnet3 neural network parameters with discriminative sequence objective \n" 35 "gradient descent. Minibatches are to be created by nnet3-discriminative-merge-egs in\n" 36 "the input pipeline. This training program is single-threaded (best to\n" 37 "use it with a GPU).\n" 39 "Usage: nnet3-discriminative-train [options] <nnet-in> <discriminative-training-examples-in> <raw-nnet-out>\n" 41 "nnet3-discriminative-train 1.mdl 'ark:nnet3-merge-egs 1.degs ark:-|' 2.raw\n";
43 bool binary_write =
true;
44 std::string use_gpu =
"yes";
45 bool dropout_test_mode =
true;
50 po.
Register(
"binary", &binary_write,
"Write output in binary mode");
52 "yes|no|optional|wait, only has effect if compiled with CUDA");
53 po.
Register(
"dropout-test-mode", &dropout_test_mode,
54 "If true, set test-mode to true on any DropoutComponents and " 55 "DropoutMaskComponents.");
67 CuDevice::Instantiate().SelectGpuId(use_gpu);
70 std::string model_rxfilename = po.
GetArg(1),
71 examples_rspecifier = po.
GetArg(2),
72 model_wxfilename = po.
GetArg(3);
78 Input ki(model_rxfilename, &binary);
85 if (dropout_test_mode)
94 for (; !example_reader.
Done(); example_reader.
Next())
100 CuDevice::Instantiate().PrintProfile();
102 Output ko(model_wxfilename, binary_write);
105 KALDI_LOG <<
"Wrote raw nnet model to " << model_wxfilename;
107 }
catch(
const std::exception &e) {
108 std::cerr << e.what() <<
'\n';
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void Write(std::ostream &ostream, bool binary) const
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
const Nnet & GetNnet() const
void Read(std::istream &is, bool binary)
void Register(const std::string &name, bool *ptr, const std::string &doc)
This file contains some miscellaneous functions dealing with class Nnet.
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
void Train(const NnetDiscriminativeExample &eg)
void Read(std::istream &is, bool binary)
int main(int argc, char *argv[])
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
This class is for single-threaded discriminative training of neural nets.
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
bool PrintTotalStats() const
const VectorBase< BaseFloat > & Priors() const
int NumArgs() const
Number of positional parameters (c.f. argc-1).
Provides a vector abstraction class.
void Register(OptionsItf *opts)