nnet3-discriminative-compute-objf.cc File Reference

Functions

int main (int argc, char *argv[])
 

Function Documentation

main()

int main(int argc, char *argv[])

Definition at line 27 of file nnet3-discriminative-compute-objf.cc.

References ParseOptions::GetArg(), AmNnetSimple::GetNnet(), ParseOptions::NumArgs(), ParseOptions::PrintUsage(), AmNnetSimple::Priors(), AmNnetSimple::Read(), ParseOptions::Read(), TransitionModel::Read(), ParseOptions::Register(), NnetComputeProbOptions::Register(), DiscriminativeOptions::Register(), kaldi::nnet3::SetBatchnormTestMode(), kaldi::nnet3::SetDropoutTestMode(), and Input::Stream().

int main(int argc, char *argv[]) {
  try {
    using namespace kaldi;
    using namespace kaldi::nnet3;
    typedef kaldi::int32 int32;
    typedef kaldi::int64 int64;

    const char *usage =
        "Computes and prints in logging messages the objective function per frame of\n"
        "the given data with an nnet3 neural net.  The input of this is the output of\n"
        "e.g. nnet3-discriminative-get-egs | nnet3-discriminative-merge-egs.\n"
        "\n"
        "Usage:  nnet3-discriminative-compute-objf [options] <nnet3-model-in> <training-examples-in>\n"
        "e.g.: nnet3-discriminative-compute-objf 0.mdl ark:valid.degs\n";

    bool batchnorm_test_mode = true, dropout_test_mode = true;

    // This program doesn't support using a GPU, because these probabilities are
    // used for diagnostics, and you can just compute them with a small enough
    // amount of data that a CPU can do it within reasonable time.
    // It wouldn't be hard to make it support GPU, though.

    NnetComputeProbOptions nnet_opts;
    discriminative::DiscriminativeOptions discriminative_opts;

    ParseOptions po(usage);

    po.Register("batchnorm-test-mode", &batchnorm_test_mode,
                "If true, set test-mode to true on any BatchNormComponents.");
    po.Register("dropout-test-mode", &dropout_test_mode,
                "If true, set test-mode to true on any DropoutComponents and "
                "DropoutMaskComponents.");

    nnet_opts.Register(&po);
    discriminative_opts.Register(&po);

    po.Read(argc, argv);

    // Exactly two positional arguments are required: model and examples.
    if (po.NumArgs() != 2) {
      po.PrintUsage();
      exit(1);
    }

    std::string model_rxfilename = po.GetArg(1),
        examples_rspecifier = po.GetArg(2);

    TransitionModel tmodel;
    AmNnetSimple am_nnet;

    // The transition model and the acoustic model are read, in that order,
    // from the same input stream.
    {
      bool binary;
      Input ki(model_rxfilename, &binary);
      tmodel.Read(ki.Stream(), binary);
      am_nnet.Read(ki.Stream(), binary);
    }

    Nnet* nnet = &(am_nnet.GetNnet());

    if (batchnorm_test_mode)
      SetBatchnormTestMode(true, nnet);

    if (dropout_test_mode)
      SetDropoutTestMode(true, nnet);

    NnetDiscriminativeComputeObjf discriminative_objf_computer(nnet_opts,
                                                               discriminative_opts,
                                                               tmodel, am_nnet.Priors(),
                                                               *nnet);

    SequentialNnetDiscriminativeExampleReader example_reader(examples_rspecifier);

    // Accumulate objective-function statistics over every example.
    for (; !example_reader.Done(); example_reader.Next())
      discriminative_objf_computer.Compute(example_reader.Value());

    bool ok = discriminative_objf_computer.PrintTotalStats();

    // Exit status: 0 on success, 1 if the statistics were not usable,
    // -1 if an exception was thrown.
    return (ok ? 0 : 1);
  } catch(const std::exception &e) {
    std::cerr << e.what() << '\n';
    return -1;
  }
}
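
The accumulate-then-report flow of main() can be illustrated without Kaldi. The following is a minimal sketch only: ExampleStats, ObjfAccumulator and RunDiagnostic are hypothetical stand-ins invented for this illustration, not Kaldi types, and the rule that PrintTotalStats() returns false when no frames were accumulated is an assumption of the sketch rather than a statement about Kaldi's implementation. It shows how a per-frame objective is obtained by dividing the summed objective by the summed frame weight, and how the boolean result maps to the exit status 0 or 1 as in the listing.

```cpp
#include <cassert>
#include <iostream>
#include <vector>

// Hypothetical stand-in for one example's contribution; not a Kaldi type.
struct ExampleStats {
  double objf;    // objective summed over the frames of this example
  double weight;  // number of frames (or total frame weight) in this example
};

// Hypothetical accumulator that mirrors the Compute()/PrintTotalStats()
// call pattern of the listing. It is not Kaldi's implementation.
class ObjfAccumulator {
 public:
  void Compute(const ExampleStats &eg) {
    assert(eg.weight >= 0.0);
    tot_objf_ += eg.objf;
    tot_weight_ += eg.weight;
  }

  // Precondition: at least one frame has been accumulated.
  double PerFrameObjf() const {
    assert(tot_weight_ > 0.0);
    return tot_objf_ / tot_weight_;
  }

  // Prints the per-frame objective. Returns false if nothing was
  // accumulated (an assumed convention for this sketch).
  bool PrintTotalStats(std::ostream &os) const {
    if (tot_weight_ == 0.0) {
      os << "no frames processed\n";
      return false;
    }
    os << "objf per frame is " << PerFrameObjf()
       << " over " << tot_weight_ << " frames\n";
    return true;
  }

 private:
  double tot_objf_ = 0.0;
  double tot_weight_ = 0.0;
};

// Returns an exit status the way the listing does: 0 if ok, 1 otherwise.
int RunDiagnostic(const std::vector<ExampleStats> &examples,
                  std::ostream &os) {
  ObjfAccumulator acc;
  for (const ExampleStats &eg : examples) acc.Compute(eg);
  return acc.PrintTotalStats(os) ? 0 : 1;
}
```

Calling RunDiagnostic with a non-empty vector and std::cout prints one summary line and returns 0; an empty vector returns 1.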
Referenced symbols

void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
Definition: nnet-utils.cc:564

void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
Definition: nnet-utils.cc:573

ParseOptions
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36

NnetDiscriminativeComputeObjf
This class is for computing objective-function values in a nnet3 discriminative training, for diagnostics.

SequentialTableReader
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287

AmNnetSimple members
const Nnet & GetNnet() const
const VectorBase< BaseFloat > & Priors() const
void Read(std::istream &is, bool binary)

TransitionModel members
void Read(std::istream &is, bool binary)
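
The example loop in main() relies on the Done()/Next()/Value() protocol of the sequential reader. As a rough illustration of that protocol only, here is a small in-memory sketch. SequentialReaderSketch and SumAll are hypothetical names made up for this example; the real reader parses an rspecifier such as ark:valid.degs and reads from an archive or script file, which this sketch does not attempt.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical in-memory stand-in for a sequential table reader.
// It only reproduces the Done()/Next()/Key()/Value() call pattern.
template <typename T>
class SequentialReaderSketch {
 public:
  explicit SequentialReaderSketch(std::vector<std::pair<std::string, T>> items)
      : items_(std::move(items)) {}

  // True once every (key, value) pair has been consumed.
  bool Done() const { return pos_ >= items_.size(); }

  // Advances to the next pair; must not be called when Done() is true.
  void Next() {
    assert(!Done());
    ++pos_;
  }

  const std::string &Key() const {
    assert(!Done());
    return items_[pos_].first;
  }

  const T &Value() const {
    assert(!Done());
    return items_[pos_].second;
  }

 private:
  std::vector<std::pair<std::string, T>> items_;
  std::size_t pos_ = 0;
};

// Consumes the reader with the same loop shape as the listing:
//   for (; !reader.Done(); reader.Next()) use(reader.Value());
double SumAll(SequentialReaderSketch<double> *reader) {
  double total = 0.0;
  for (; !reader->Done(); reader->Next()) total += reader->Value();
  return total;
}
```

The reader is single-pass: once the loop finishes, Done() stays true and the values cannot be revisited, which matches how main() makes exactly one sweep over the examples.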