29 using namespace kaldi;
32 typedef kaldi::int64 int64;
35 "Computes and prints to in logging messages the objective function per frame of\n" 36 "the given data with an nnet3 neural net. The input of this is the output of\n" 37 "e.g. nnet3-discriminative-get-egs | nnet3-discriminative-merge-egs.\n" 39 "Usage: nnet3-discrminative-compute-objf [options] <nnet3-model-in> <training-examples-in>\n" 40 "e.g.: nnet3-discriminative-compute-objf 0.mdl ark:valid.degs\n";
42 bool batchnorm_test_mode =
true, dropout_test_mode =
true;
54 po.Register(
"batchnorm-test-mode", &batchnorm_test_mode,
55 "If true, set test-mode to true on any BatchNormComponents.");
56 po.Register(
"dropout-test-mode", &dropout_test_mode,
57 "If true, set test-mode to true on any DropoutComponents and " 58 "DropoutMaskComponents.");
65 if (po.NumArgs() != 2) {
70 std::string model_rxfilename = po.GetArg(1),
71 examples_rspecifier = po.GetArg(2);
78 Input ki(model_rxfilename, &binary);
79 tmodel.
Read(ki.Stream(), binary);
80 am_nnet.
Read(ki.Stream(), binary);
85 if (batchnorm_test_mode)
88 if (dropout_test_mode)
98 for (; !example_reader.Done(); example_reader.Next())
99 discriminative_objf_computer.Compute(example_reader.Value());
101 bool ok = discriminative_objf_computer.PrintTotalStats();
104 }
catch(
const std::exception &e) {
105 std::cerr << e.what() <<
'\n';
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
void Register(OptionsItf *opts)
const Nnet & GetNnet() const
void Read(std::istream &is, bool binary)
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
void Register(OptionsItf *opts)
void Read(std::istream &is, bool binary)
This class is for computing objective-function values in a nnet3 discriminative training, for diagnostics.
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
const VectorBase< BaseFloat > & Priors() const