nnet3-compute-batch.cc
Go to the documentation of this file.
1 // nnet3bin/nnet3-compute-batch.cc
2 
3 // Copyright 2012-2018 Johns Hopkins University (author: Daniel Povey)
4 // 2018 Hang Lyu
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "nnet3/nnet-batch-compute.h"
#include "base/timer.h"
#include "nnet3/nnet-utils.h"
27 
28 
29 int main(int argc, char *argv[]) {
30  try {
31  using namespace kaldi;
32  using namespace kaldi::nnet3;
33  typedef kaldi::int32 int32;
34  typedef kaldi::int64 int64;
35 
36  const char *usage =
37  "Propagate the features through raw neural network model "
38  "and write the output. This version is optimized for GPU use. "
39  "If --apply-exp=true, apply the Exp() function to the output "
40  "before writing it out.\n"
41  "\n"
42  "Usage: nnet3-compute-batch [options] <nnet-in> <features-rspecifier> "
43  "<matrix-wspecifier>\n"
44  " e.g.: nnet3-compute-batch final.raw scp:feats.scp "
45  "ark:nnet_prediction.ark\n";
46 
47  ParseOptions po(usage);
48  Timer timer;
49 
51  opts.acoustic_scale = 1.0; // by default do no scaling
52 
53  bool apply_exp = false, use_priors = false;
54  std::string use_gpu = "yes";
55 
56  std::string word_syms_filename;
57  std::string ivector_rspecifier,
58  online_ivector_rspecifier,
59  utt2spk_rspecifier;
60  int32 online_ivector_period = 0;
61  opts.Register(&po);
62 
63  po.Register("ivectors", &ivector_rspecifier, "Rspecifier for "
64  "iVectors as vectors (i.e. not estimated online); per "
65  "utterance by default, or per speaker if you provide the "
66  "--utt2spk option.");
67  po.Register("utt2spk", &utt2spk_rspecifier, "Rspecifier for "
68  "utt2spk option used to get ivectors per speaker");
69  po.Register("online-ivectors", &online_ivector_rspecifier, "Rspecifier for "
70  "iVectors estimated online, as matrices. If you supply this,"
71  " you must set the --online-ivector-period option.");
72  po.Register("online-ivector-period", &online_ivector_period, "Number of "
73  "frames between iVectors in matrices supplied to the "
74  "--online-ivectors option");
75  po.Register("apply-exp", &apply_exp, "If true, apply exp function to "
76  "output");
77  po.Register("use-gpu", &use_gpu,
78  "yes|no|optional|wait, only has effect if compiled with CUDA");
79  po.Register("use-priors", &use_priors, "If true, subtract the logs of the "
80  "priors stored with the model (in this case, "
81  "a .mdl file is expected as input).");
82 
83 #if HAVE_CUDA==1
84  CuDevice::RegisterDeviceOptions(&po);
85 #endif
86 
87  po.Read(argc, argv);
88 
89  if (po.NumArgs() != 3) {
90  po.PrintUsage();
91  exit(1);
92  }
93 
94 #if HAVE_CUDA==1
95  CuDevice::Instantiate().AllowMultithreading();
96  CuDevice::Instantiate().SelectGpuId(use_gpu);
97 #endif
98 
99  std::string nnet_rxfilename = po.GetArg(1),
100  feature_rspecifier = po.GetArg(2),
101  matrix_wspecifier = po.GetArg(3);
102 
103  Nnet raw_nnet;
104  AmNnetSimple am_nnet;
105  if (use_priors) {
106  bool binary;
107  TransitionModel trans_model;
108  Input ki(nnet_rxfilename, &binary);
109  trans_model.Read(ki.Stream(), binary);
110  am_nnet.Read(ki.Stream(), binary);
111  } else {
112  ReadKaldiObject(nnet_rxfilename, &raw_nnet);
113  }
114  Nnet &nnet = (use_priors ? am_nnet.GetNnet() : raw_nnet);
115  SetBatchnormTestMode(true, &nnet);
116  SetDropoutTestMode(true, &nnet);
118 
119  Vector<BaseFloat> priors;
120  if (use_priors)
121  priors = am_nnet.Priors();
122 
123  RandomAccessBaseFloatMatrixReader online_ivector_reader(
124  online_ivector_rspecifier);
126  ivector_rspecifier, utt2spk_rspecifier);
127 
128  BaseFloatMatrixWriter matrix_writer(matrix_wspecifier);
129 
130  int32 num_success = 0, num_fail = 0;
131  std::string output_uttid;
132  Matrix<BaseFloat> output_matrix;
133 
134 
135  NnetBatchInference inference(opts, nnet, priors);
136 
137  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
138 
139  for (; !feature_reader.Done(); feature_reader.Next()) {
140  std::string utt = feature_reader.Key();
141  const Matrix<BaseFloat> &features = feature_reader.Value();
142  if (features.NumRows() == 0) {
143  KALDI_WARN << "Zero-length utterance: " << utt;
144  num_fail++;
145  continue;
146  }
147  const Matrix<BaseFloat> *online_ivectors = NULL;
148  const Vector<BaseFloat> *ivector = NULL;
149  if (!ivector_rspecifier.empty()) {
150  if (!ivector_reader.HasKey(utt)) {
151  KALDI_WARN << "No iVector available for utterance " << utt;
152  num_fail++;
153  continue;
154  } else {
155  ivector = new Vector<BaseFloat>(ivector_reader.Value(utt));
156  }
157  }
158  if (!online_ivector_rspecifier.empty()) {
159  if (!online_ivector_reader.HasKey(utt)) {
160  KALDI_WARN << "No online iVector available for utterance " << utt;
161  num_fail++;
162  continue;
163  } else {
164  online_ivectors = new Matrix<BaseFloat>(
165  online_ivector_reader.Value(utt));
166  }
167  }
168 
169  inference.AcceptInput(utt, features, ivector, online_ivectors,
170  online_ivector_period);
171 
172  std::string output_key;
173  Matrix<BaseFloat> output;
174  while (inference.GetOutput(&output_key, &output)) {
175  if (apply_exp)
176  output.ApplyExp();
177  matrix_writer.Write(output_key, output);
178  num_success++;
179  }
180  }
181 
182  inference.Finished();
183  std::string output_key;
184  Matrix<BaseFloat> output;
185  while (inference.GetOutput(&output_key, &output)) {
186  if (apply_exp)
187  output.ApplyExp();
188  matrix_writer.Write(output_key, output);
189  num_success++;
190  }
191 #if HAVE_CUDA==1
192  CuDevice::Instantiate().PrintProfile();
193 #endif
194  double elapsed = timer.Elapsed();
195  KALDI_LOG << "Time taken "<< elapsed << "s";
196  KALDI_LOG << "Done " << num_success << " utterances, failed for "
197  << num_fail;
198 
199  if (num_success != 0) {
200  return 0;
201  } else {
202  return 1;
203  }
204  } catch(const std::exception &e) {
205  std::cerr << e.what();
206  return -1;
207  }
208 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable to be done at test time...
Definition: nnet-utils.cc:2100
int main(int argc, char *argv[])
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
This class is for when you are reading something in random access, but it may actually be stored per-...
Definition: kaldi-table.h:432
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
Definition: nnet-utils.cc:564
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
const Nnet & GetNnet() const
void AcceptInput(const std::string &utterance_id, const Matrix< BaseFloat > &input, const Vector< BaseFloat > *ivector, const Matrix< BaseFloat > *online_ivectors, int32 online_ivector_period)
The user should call this one by one for the utterances that this class needs to compute (intersperse...
void Write(const std::string &key, const T &value) const
void Read(std::istream &is, bool binary)
void Register(const std::string &name, bool *ptr, const std::string &doc)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
This file contains some miscellaneous functions dealing with class Nnet.
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
This class implements a simplified interface to class NnetBatchComputer, which is suitable for progra...
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
Definition: nnet-utils.cc:573
std::istream & Stream()
Definition: kaldi-io.cc:826
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const T & Value(const std::string &key)
void Read(std::istream &is, bool binary)
void Finished()
The user should call this after the last input has been provided via AcceptInput().
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
bool HasKey(const std::string &key)
const VectorBase< BaseFloat > & Priors() const
int NumArgs() const
Number of positional parameters (c.f. argc-1).
A class representing a vector.
Definition: kaldi-vector.h:406
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
const T & Value(const std::string &key)
#define KALDI_LOG
Definition: kaldi-error.h:153
double Elapsed() const
Returns time in seconds.
Definition: timer.h:74
bool GetOutput(std::string *utterance_id, Matrix< BaseFloat > *output)
The user should call this to obtain output.
Config class for the CollapseModel function.
Definition: nnet-utils.h:240