nnet3-discriminative-compute-objf.cc
Go to the documentation of this file.
1 // nnet3bin/nnet3-discriminative-compute-objf.cc
2 
3 // Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey)
4 // 2014-2015 Vimal Manohar
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include "base/kaldi-common.h"
22 #include "util/common-utils.h"
24 #include "nnet3/am-nnet-simple.h"
25 #include "nnet3/nnet-utils.h"
26 
27 int main(int argc, char *argv[]) {
28  try {
29  using namespace kaldi;
30  using namespace kaldi::nnet3;
31  typedef kaldi::int32 int32;
32  typedef kaldi::int64 int64;
33 
34  const char *usage =
35  "Computes and prints to in logging messages the objective function per frame of\n"
36  "the given data with an nnet3 neural net. The input of this is the output of\n"
37  "e.g. nnet3-discriminative-get-egs | nnet3-discriminative-merge-egs.\n"
38  "\n"
39  "Usage: nnet3-discrminative-compute-objf [options] <nnet3-model-in> <training-examples-in>\n"
40  "e.g.: nnet3-discriminative-compute-objf 0.mdl ark:valid.degs\n";
41 
42  bool batchnorm_test_mode = true, dropout_test_mode = true;
43 
44  // This program doesn't support using a GPU, because these probabilities are
45  // used for diagnostics, and you can just compute them with a small enough
46  // amount of data that a CPU can do it within reasonable time.
47  // It wouldn't be hard to make it support GPU, though.
48 
49  NnetComputeProbOptions nnet_opts;
50  discriminative::DiscriminativeOptions discriminative_opts;
51 
52  ParseOptions po(usage);
53 
54  po.Register("batchnorm-test-mode", &batchnorm_test_mode,
55  "If true, set test-mode to true on any BatchNormComponents.");
56  po.Register("dropout-test-mode", &dropout_test_mode,
57  "If true, set test-mode to true on any DropoutComponents and "
58  "DropoutMaskComponents.");
59 
60  nnet_opts.Register(&po);
61  discriminative_opts.Register(&po);
62 
63  po.Read(argc, argv);
64 
65  if (po.NumArgs() != 2) {
66  po.PrintUsage();
67  exit(1);
68  }
69 
70  std::string model_rxfilename = po.GetArg(1),
71  examples_rspecifier = po.GetArg(2);
72 
73  TransitionModel tmodel;
74  AmNnetSimple am_nnet;
75 
76  {
77  bool binary;
78  Input ki(model_rxfilename, &binary);
79  tmodel.Read(ki.Stream(), binary);
80  am_nnet.Read(ki.Stream(), binary);
81  }
82 
83  Nnet* nnet = &(am_nnet.GetNnet());
84 
85  if (batchnorm_test_mode)
86  SetBatchnormTestMode(true, nnet);
87 
88  if (dropout_test_mode)
89  SetDropoutTestMode(true, nnet);
90 
91  NnetDiscriminativeComputeObjf discriminative_objf_computer(nnet_opts,
92  discriminative_opts,
93  tmodel, am_nnet.Priors(),
94  *nnet);
95 
96  SequentialNnetDiscriminativeExampleReader example_reader(examples_rspecifier);
97 
98  for (; !example_reader.Done(); example_reader.Next())
99  discriminative_objf_computer.Compute(example_reader.Value());
100 
101  bool ok = discriminative_objf_computer.PrintTotalStats();
102 
103  return (ok ? 0 : 1);
104  } catch(const std::exception &e) {
105  std::cerr << e.what() << '\n';
106  return -1;
107  }
108 }
109 
110 
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
Definition: chain.dox:20
int main(int argc, char *argv[])
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
Definition: nnet-utils.cc:564
kaldi::int32 int32
const Nnet & GetNnet() const
void Read(std::istream &is, bool binary)
void Register(const std::string &name, bool *ptr, const std::string &doc)
This file contains some miscellaneous functions dealing with class Nnet.
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
Definition: nnet-utils.cc:573
std::istream & Stream()
Definition: kaldi-io.cc:826
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Read(std::istream &is, bool binary)
This class is for computing objective-function values in nnet3 discriminative training, for diagnostics.
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
const VectorBase< BaseFloat > & Priors() const
int NumArgs() const
Number of positional parameters (c.f. argc-1).