nnet3-compute-batch.cc File Reference
Include dependency graph for nnet3-compute-batch.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

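Definition at line 29 of file nnet3-compute-batch.cc. Note: the listing below omits source lines 50, 117 and 125 (the line numbering skips over them), so the declarations of `opts` and `ivector_reader` are not shown.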

References NnetBatchInference::AcceptInput(), NnetSimpleComputationOptions::acoustic_scale, MatrixBase< Real >::ApplyExp(), kaldi::nnet3::CollapseModel(), SequentialTableReader< Holder >::Done(), Timer::Elapsed(), NnetBatchInference::Finished(), ParseOptions::GetArg(), AmNnetSimple::GetNnet(), NnetBatchInference::GetOutput(), RandomAccessTableReader< Holder >::HasKey(), RandomAccessTableReaderMapped< Holder >::HasKey(), KALDI_LOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), MatrixBase< Real >::NumRows(), ParseOptions::PrintUsage(), AmNnetSimple::Priors(), AmNnetSimple::Read(), ParseOptions::Read(), TransitionModel::Read(), kaldi::ReadKaldiObject(), ParseOptions::Register(), NnetBatchComputerOptions::Register(), kaldi::nnet3::SetBatchnormTestMode(), kaldi::nnet3::SetDropoutTestMode(), Input::Stream(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), RandomAccessTableReaderMapped< Holder >::Value(), and TableWriter< Holder >::Write().

29  {
30  try {
31  using namespace kaldi;
32  using namespace kaldi::nnet3;
33  typedef kaldi::int32 int32;
34  typedef kaldi::int64 int64;
35 
36  const char *usage =
37  "Propagate the features through raw neural network model "
38  "and write the output. This version is optimized for GPU use. "
39  "If --apply-exp=true, apply the Exp() function to the output "
40  "before writing it out.\n"
41  "\n"
42  "Usage: nnet3-compute-batch [options] <nnet-in> <features-rspecifier> "
43  "<matrix-wspecifier>\n"
44  " e.g.: nnet3-compute-batch final.raw scp:feats.scp "
45  "ark:nnet_prediction.ark\n";
46 
47  ParseOptions po(usage);
48  Timer timer;
49 
51  opts.acoustic_scale = 1.0; // by default do no scaling
52 
53  bool apply_exp = false, use_priors = false;
54  std::string use_gpu = "yes";
55 
56  std::string word_syms_filename;
57  std::string ivector_rspecifier,
58  online_ivector_rspecifier,
59  utt2spk_rspecifier;
60  int32 online_ivector_period = 0;
61  opts.Register(&po);
62 
63  po.Register("ivectors", &ivector_rspecifier, "Rspecifier for "
64  "iVectors as vectors (i.e. not estimated online); per "
65  "utterance by default, or per speaker if you provide the "
66  "--utt2spk option.");
67  po.Register("utt2spk", &utt2spk_rspecifier, "Rspecifier for "
68  "utt2spk option used to get ivectors per speaker");
69  po.Register("online-ivectors", &online_ivector_rspecifier, "Rspecifier for "
70  "iVectors estimated online, as matrices. If you supply this,"
71  " you must set the --online-ivector-period option.");
72  po.Register("online-ivector-period", &online_ivector_period, "Number of "
73  "frames between iVectors in matrices supplied to the "
74  "--online-ivectors option");
75  po.Register("apply-exp", &apply_exp, "If true, apply exp function to "
76  "output");
77  po.Register("use-gpu", &use_gpu,
78  "yes|no|optional|wait, only has effect if compiled with CUDA");
79  po.Register("use-priors", &use_priors, "If true, subtract the logs of the "
80  "priors stored with the model (in this case, "
81  "a .mdl file is expected as input).");
82 
83 #if HAVE_CUDA==1
84  CuDevice::RegisterDeviceOptions(&po);
85 #endif
86 
87  po.Read(argc, argv);
88 
89  if (po.NumArgs() != 3) {
90  po.PrintUsage();
91  exit(1);
92  }
93 
94 #if HAVE_CUDA==1
95  CuDevice::Instantiate().AllowMultithreading();
96  CuDevice::Instantiate().SelectGpuId(use_gpu);
97 #endif
98 
99  std::string nnet_rxfilename = po.GetArg(1),
100  feature_rspecifier = po.GetArg(2),
101  matrix_wspecifier = po.GetArg(3);
102 
103  Nnet raw_nnet;
104  AmNnetSimple am_nnet;
105  if (use_priors) {
106  bool binary;
107  TransitionModel trans_model;
108  Input ki(nnet_rxfilename, &binary);
109  trans_model.Read(ki.Stream(), binary);
110  am_nnet.Read(ki.Stream(), binary);
111  } else {
112  ReadKaldiObject(nnet_rxfilename, &raw_nnet);
113  }
114  Nnet &nnet = (use_priors ? am_nnet.GetNnet() : raw_nnet);
115  SetBatchnormTestMode(true, &nnet);
116  SetDropoutTestMode(true, &nnet);
118 
119  Vector<BaseFloat> priors;
120  if (use_priors)
121  priors = am_nnet.Priors();
122 
123  RandomAccessBaseFloatMatrixReader online_ivector_reader(
124  online_ivector_rspecifier);
126  ivector_rspecifier, utt2spk_rspecifier);
127 
128  BaseFloatMatrixWriter matrix_writer(matrix_wspecifier);
129 
130  int32 num_success = 0, num_fail = 0;
131  std::string output_uttid;
132  Matrix<BaseFloat> output_matrix;
133 
134 
135  NnetBatchInference inference(opts, nnet, priors);
136 
137  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
138 
139  for (; !feature_reader.Done(); feature_reader.Next()) {
140  std::string utt = feature_reader.Key();
141  const Matrix<BaseFloat> &features = feature_reader.Value();
142  if (features.NumRows() == 0) {
143  KALDI_WARN << "Zero-length utterance: " << utt;
144  num_fail++;
145  continue;
146  }
147  const Matrix<BaseFloat> *online_ivectors = NULL;
148  const Vector<BaseFloat> *ivector = NULL;
149  if (!ivector_rspecifier.empty()) {
150  if (!ivector_reader.HasKey(utt)) {
151  KALDI_WARN << "No iVector available for utterance " << utt;
152  num_fail++;
153  continue;
154  } else {
155  ivector = new Vector<BaseFloat>(ivector_reader.Value(utt));
156  }
157  }
158  if (!online_ivector_rspecifier.empty()) {
159  if (!online_ivector_reader.HasKey(utt)) {
160  KALDI_WARN << "No online iVector available for utterance " << utt;
161  num_fail++;
162  continue;
163  } else {
164  online_ivectors = new Matrix<BaseFloat>(
165  online_ivector_reader.Value(utt));
166  }
167  }
168 
169  inference.AcceptInput(utt, features, ivector, online_ivectors,
170  online_ivector_period);
171 
172  std::string output_key;
173  Matrix<BaseFloat> output;
174  while (inference.GetOutput(&output_key, &output)) {
175  if (apply_exp)
176  output.ApplyExp();
177  matrix_writer.Write(output_key, output);
178  num_success++;
179  }
180  }
181 
182  inference.Finished();
183  std::string output_key;
184  Matrix<BaseFloat> output;
185  while (inference.GetOutput(&output_key, &output)) {
186  if (apply_exp)
187  output.ApplyExp();
188  matrix_writer.Write(output_key, output);
189  num_success++;
190  }
191 #if HAVE_CUDA==1
192  CuDevice::Instantiate().PrintProfile();
193 #endif
194  double elapsed = timer.Elapsed();
195  KALDI_LOG << "Time taken "<< elapsed << "s";
196  KALDI_LOG << "Done " << num_success << " utterances, failed for "
197  << num_fail;
198 
199  if (num_success != 0) {
200  return 0;
201  } else {
202  return 1;
203  }
204  } catch(const std::exception &e) {
205  std::cerr << e.what();
206  return -1;
207  }
208 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable to be done at test time...
Definition: nnet-utils.cc:2100
This class is for when you are reading something in random access, but it may actually be stored per-...
Definition: kaldi-table.h:432
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
Definition: nnet-utils.cc:564
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
const Nnet & GetNnet() const
void Read(std::istream &is, bool binary)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
This class implements a simplified interface to class NnetBatchComputer, which is suitable for progra...
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
Definition: nnet-utils.cc:573
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
#define KALDI_WARN
Definition: kaldi-error.h:150
const VectorBase< BaseFloat > & Priors() const
A class representing a vector.
Definition: kaldi-vector.h:406
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
#define KALDI_LOG
Definition: kaldi-error.h:153
double Elapsed() const
Returns time in seconds.
Definition: timer.h:74
Config class for the CollapseModel function.
Definition: nnet-utils.h:240