nnet3-compute.cc File Reference
Include dependency graph for nnet3-compute.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 29 of file nnet3-compute.cc.

References NnetSimpleComputationOptions::acoustic_scale, kaldi::nnet3::CollapseModel(), SequentialTableReader< Holder >::Done(), Timer::Elapsed(), ParseOptions::GetArg(), AmNnetSimple::GetNnet(), DecodableNnetSimple::GetOutputForFrame(), RandomAccessTableReader< Holder >::HasKey(), RandomAccessTableReaderMapped< Holder >::HasKey(), KALDI_LOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), DecodableNnetSimple::NumFrames(), NnetSimpleComputationOptions::optimize_config, DecodableNnetSimple::OutputDim(), ParseOptions::PrintUsage(), AmNnetSimple::Priors(), AmNnetSimple::Read(), ParseOptions::Read(), TransitionModel::Read(), kaldi::ReadKaldiObject(), NnetSimpleComputationOptions::Register(), ParseOptions::Register(), kaldi::nnet3::SetBatchnormTestMode(), kaldi::nnet3::SetDropoutTestMode(), Input::Stream(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), RandomAccessTableReaderMapped< Holder >::Value(), and TableWriter< Holder >::Write().

29  {
30  try {
31  using namespace kaldi;
32  using namespace kaldi::nnet3;
33  typedef kaldi::int32 int32;
34  typedef kaldi::int64 int64;
35 
36  const char *usage =
37  "Propagate the features through raw neural network model "
38  "and write the output.\n"
39  "If --apply-exp=true, apply the Exp() function to the output "
40  "before writing it out.\n"
41  "\n"
42  "Usage: nnet3-compute [options] <nnet-in> <features-rspecifier> <matrix-wspecifier>\n"
43  " e.g.: nnet3-compute final.raw scp:feats.scp ark:nnet_prediction.ark\n"
44  "See also: nnet3-compute-from-egs, nnet3-chain-compute-post\n"
45  "Note: this program does not currently make very efficient use of the GPU.\n";
46 
47  ParseOptions po(usage);
48  Timer timer;
49 
50  NnetSimpleComputationOptions opts;
51  opts.acoustic_scale = 1.0; // by default do no scaling.
52 
53  bool apply_exp = false, use_priors = false;
54  std::string use_gpu = "yes";
55 
56  std::string ivector_rspecifier,
57  online_ivector_rspecifier,
58  utt2spk_rspecifier;
59  int32 online_ivector_period = 0;
60  opts.Register(&po);
61 
62  po.Register("ivectors", &ivector_rspecifier, "Rspecifier for "
63  "iVectors as vectors (i.e. not estimated online); per utterance "
64  "by default, or per speaker if you provide the --utt2spk option.");
65  po.Register("utt2spk", &utt2spk_rspecifier, "Rspecifier for "
66  "utt2spk option used to get ivectors per speaker");
67  po.Register("online-ivectors", &online_ivector_rspecifier, "Rspecifier for "
68  "iVectors estimated online, as matrices. If you supply this,"
69  " you must set the --online-ivector-period option.");
70  po.Register("online-ivector-period", &online_ivector_period, "Number of frames "
71  "between iVectors in matrices supplied to the --online-ivectors "
72  "option");
73  po.Register("apply-exp", &apply_exp, "If true, apply exp function to "
74  "output");
75  po.Register("use-gpu", &use_gpu,
76  "yes|no|optional|wait, only has effect if compiled with CUDA");
77  po.Register("use-priors", &use_priors, "If true, subtract the logs of the "
78  "priors stored with the model (in this case, "
79  "a .mdl file is expected as input).");
80 
81 #if HAVE_CUDA==1
82  CuDevice::RegisterDeviceOptions(&po);
83 #endif
84 
85  po.Read(argc, argv);
86 
87  if (po.NumArgs() != 3) {
88  po.PrintUsage();
89  exit(1);
90  }
91 
92 #if HAVE_CUDA==1
93  CuDevice::Instantiate().SelectGpuId(use_gpu);
94 #endif
95 
96  std::string nnet_rxfilename = po.GetArg(1),
97  feature_rspecifier = po.GetArg(2),
98  matrix_wspecifier = po.GetArg(3);
99 
100  Nnet raw_nnet;
101  AmNnetSimple am_nnet;
102  if (use_priors) {
103  bool binary;
104  TransitionModel trans_model;
105  Input ki(nnet_rxfilename, &binary);
106  trans_model.Read(ki.Stream(), binary);
107  am_nnet.Read(ki.Stream(), binary);
108  } else {
109  ReadKaldiObject(nnet_rxfilename, &raw_nnet);
110  }
111  Nnet &nnet = (use_priors ? am_nnet.GetNnet() : raw_nnet);
112  SetBatchnormTestMode(true, &nnet);
113  SetDropoutTestMode(true, &nnet);
114  CollapseModel(CollapseModelConfig(), &nnet);
115 
116  Vector<BaseFloat> priors;
117  if (use_priors)
118  priors = am_nnet.Priors();
119 
120  RandomAccessBaseFloatMatrixReader online_ivector_reader(
121  online_ivector_rspecifier);
122  RandomAccessBaseFloatVectorReaderMapped ivector_reader(
123  ivector_rspecifier, utt2spk_rspecifier);
124 
125  CachingOptimizingCompiler compiler(nnet, opts.optimize_config);
126 
127  BaseFloatMatrixWriter matrix_writer(matrix_wspecifier);
128 
129  int32 num_success = 0, num_fail = 0;
130  int64 frame_count = 0;
131 
132  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
133 
134  for (; !feature_reader.Done(); feature_reader.Next()) {
135  std::string utt = feature_reader.Key();
136  const Matrix<BaseFloat> &features (feature_reader.Value());
137  if (features.NumRows() == 0) {
138  KALDI_WARN << "Zero-length utterance: " << utt;
139  num_fail++;
140  continue;
141  }
142  const Matrix<BaseFloat> *online_ivectors = NULL;
143  const Vector<BaseFloat> *ivector = NULL;
144  if (!ivector_rspecifier.empty()) {
145  if (!ivector_reader.HasKey(utt)) {
146  KALDI_WARN << "No iVector available for utterance " << utt;
147  num_fail++;
148  continue;
149  } else {
150  ivector = &ivector_reader.Value(utt);
151  }
152  }
153  if (!online_ivector_rspecifier.empty()) {
154  if (!online_ivector_reader.HasKey(utt)) {
155  KALDI_WARN << "No online iVector available for utterance " << utt;
156  num_fail++;
157  continue;
158  } else {
159  online_ivectors = &online_ivector_reader.Value(utt);
160  }
161  }
162 
163  DecodableNnetSimple nnet_computer(
164  opts, nnet, priors,
165  features, &compiler,
166  ivector, online_ivectors,
167  online_ivector_period);
168 
169  Matrix<BaseFloat> matrix(nnet_computer.NumFrames(),
170  nnet_computer.OutputDim());
171  for (int32 t = 0; t < nnet_computer.NumFrames(); t++) {
172  SubVector<BaseFloat> row(matrix, t);
173  nnet_computer.GetOutputForFrame(t, &row);
174  }
175 
176  if (apply_exp)
177  matrix.ApplyExp();
178 
179  matrix_writer.Write(utt, matrix);
180 
181  frame_count += features.NumRows();
182  num_success++;
183  }
184 
185 #if HAVE_CUDA==1
186  CuDevice::Instantiate().PrintProfile();
187 #endif
188  double elapsed = timer.Elapsed();
189  KALDI_LOG << "Time taken "<< elapsed
190  << "s: real-time factor assuming 100 frames/sec is "
191  << (elapsed*100.0/frame_count);
192  KALDI_LOG << "Done " << num_success << " utterances, failed for "
193  << num_fail;
194 
195  if (num_success != 0) return 0;
196  else return 1;
197  } catch(const std::exception &e) {
198  std::cerr << e.what();
199  return -1;
200  }
201 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable to be done at test time...
Definition: nnet-utils.cc:2100
This class is for when you are reading something in random access, but it may actually be stored per-...
Definition: kaldi-table.h:432
This class enables you to do the compilation and optimization in one call, and also ensures that if t...
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
Definition: nnet-utils.cc:564
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
const Nnet & GetNnet() const
void Read(std::istream &is, bool binary)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
Definition: nnet-utils.cc:573
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
#define KALDI_WARN
Definition: kaldi-error.h:150
const VectorBase< BaseFloat > & Priors() const
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_LOG
Definition: kaldi-error.h:153
double Elapsed() const
Returns time in seconds.
Definition: timer.h:74
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:501
Config class for the CollapseModel function.
Definition: nnet-utils.h:240