nnet3-latgen-faster-parallel.cc File Reference
Include dependency graph for nnet3-latgen-faster-parallel.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 35 of file nnet3-latgen-faster-parallel.cc.

References NnetSimpleComputationOptions::acoustic_scale, kaldi::ClassifyRspecifier(), kaldi::nnet3::CollapseModel(), LatticeFasterDecoderConfig::determinize_lattice, SequentialTableReader< Holder >::Done(), Timer::Elapsed(), NnetSimpleComputationOptions::frame_subsampling_factor, ParseOptions::GetArg(), AmNnetSimple::GetNnet(), ParseOptions::GetOptArg(), RandomAccessTableReader< Holder >::HasKey(), RandomAccessTableReaderMapped< Holder >::HasKey(), KALDI_ERR, KALDI_LOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), kaldi::kNoRspecifier, SequentialTableReader< Holder >::Next(), TaskSequencerConfig::num_threads, ParseOptions::NumArgs(), MatrixBase< Real >::NumRows(), TableWriter< Holder >::Open(), ParseOptions::PrintUsage(), AmNnetSimple::Read(), ParseOptions::Read(), TransitionModel::Read(), fst::ReadFstKaldiGeneric(), LatticeFasterDecoderConfig::Register(), NnetSimpleComputationOptions::Register(), ParseOptions::Register(), TaskSequencerConfig::Register(), Timer::Reset(), TaskSequencer< C >::Run(), kaldi::nnet3::SetBatchnormTestMode(), kaldi::nnet3::SetDropoutTestMode(), Input::Stream(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), RandomAccessTableReaderMapped< Holder >::Value(), and TaskSequencer< C >::Wait().

35  {
36  // note: making this program work with GPUs is as simple as initializing the
37  // device, but it probably won't make a huge difference in speed for typical
38  // setups.
39  try {
40  using namespace kaldi;
41  using namespace kaldi::nnet3;
42  typedef kaldi::int32 int32;
43  using fst::SymbolTable;
44  using fst::Fst;
45  using fst::StdArc;
46 
47  const char *usage =
48  "Generate lattices using nnet3 neural net model. This version supports\n"
49  "multiple decoding threads (using a shared decoding graph.)\n"
50  "Usage: nnet3-latgen-faster-parallel [options] <nnet-in> <fst-in|fsts-rspecifier> <features-rspecifier>"
51  " <lattice-wspecifier> [ <words-wspecifier> [<alignments-wspecifier>] ]\n"
52  "See also: nnet3-latgen-faster-batch (which supports GPUs)\n";
53  ParseOptions po(usage);
54 
55  Timer timer;
56  bool allow_partial = false;
57  TaskSequencerConfig sequencer_config; // has --num-threads option
58  LatticeFasterDecoderConfig config;
59  NnetSimpleComputationOptions decodable_opts;
60 
61  std::string word_syms_filename;
62  std::string ivector_rspecifier,
63  online_ivector_rspecifier,
64  utt2spk_rspecifier;
65  int32 online_ivector_period = 0;
66  sequencer_config.Register(&po);
67  config.Register(&po);
68  decodable_opts.Register(&po);
69  po.Register("word-symbol-table", &word_syms_filename,
70  "Symbol table for words [for debug output]");
71  po.Register("allow-partial", &allow_partial,
72  "If true, produce output even if end state was not reached.");
73  po.Register("ivectors", &ivector_rspecifier, "Rspecifier for "
74  "iVectors as vectors (i.e. not estimated online); per utterance "
75  "by default, or per speaker if you provide the --utt2spk option.");
76  po.Register("online-ivectors", &online_ivector_rspecifier, "Rspecifier for "
77  "iVectors estimated online, as matrices. If you supply this,"
78  " you must set the --online-ivector-period option.");
79  po.Register("online-ivector-period", &online_ivector_period, "Number of frames "
80  "between iVectors in matrices supplied to the --online-ivectors "
81  "option");
82 
83  po.Read(argc, argv);
84 
85  if (po.NumArgs() < 4 || po.NumArgs() > 6) {
86  po.PrintUsage();
87  exit(1);
88  }
89 
90  std::string model_in_filename = po.GetArg(1),
91  fst_in_str = po.GetArg(2),
92  feature_rspecifier = po.GetArg(3),
93  lattice_wspecifier = po.GetArg(4),
94  words_wspecifier = po.GetOptArg(5),
95  alignment_wspecifier = po.GetOptArg(6);
96 
97  TaskSequencer<DecodeUtteranceLatticeFasterClass> sequencer(sequencer_config);
98  TransitionModel trans_model;
99  AmNnetSimple am_nnet;
100  {
101  bool binary;
102  Input ki(model_in_filename, &binary);
103  trans_model.Read(ki.Stream(), binary);
104  am_nnet.Read(ki.Stream(), binary);
105  SetBatchnormTestMode(true, &(am_nnet.GetNnet()));
106  SetDropoutTestMode(true, &(am_nnet.GetNnet()));
107  CollapseModel(CollapseModelConfig(), &(am_nnet.GetNnet()));
108  }
109 
110  bool determinize = config.determinize_lattice;
111  CompactLatticeWriter compact_lattice_writer;
112  LatticeWriter lattice_writer;
113  if (! (determinize ? compact_lattice_writer.Open(lattice_wspecifier)
114  : lattice_writer.Open(lattice_wspecifier)))
115  KALDI_ERR << "Could not open table for writing lattices: "
116  << lattice_wspecifier;
117 
118  RandomAccessBaseFloatMatrixReader online_ivector_reader(
119  online_ivector_rspecifier);
120  RandomAccessBaseFloatVectorReaderMapped ivector_reader(
121  ivector_rspecifier, utt2spk_rspecifier);
122 
123  Int32VectorWriter words_writer(words_wspecifier);
124  Int32VectorWriter alignment_writer(alignment_wspecifier);
125 
126  fst::SymbolTable *word_syms = NULL;
127  if (word_syms_filename != "")
128  if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename)))
129  KALDI_ERR << "Could not read symbol table from file "
130  << word_syms_filename;
131 
132  double tot_like = 0.0;
133  kaldi::int64 frame_count = 0;
134  int num_success = 0, num_fail = 0;
135 
136  if (ClassifyRspecifier(fst_in_str, NULL, NULL) == kNoRspecifier) {
137  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
138 
139  // Input FST is just one FST, not a table of FSTs.
140  Fst<StdArc> *decode_fst = fst::ReadFstKaldiGeneric(fst_in_str);
141  timer.Reset();
142 
143  {
144  for (; !feature_reader.Done(); feature_reader.Next()) {
145  std::string utt = feature_reader.Key();
146  const Matrix<BaseFloat> &features (feature_reader.Value());
147  if (features.NumRows() == 0) {
148  KALDI_WARN << "Zero-length utterance: " << utt;
149  num_fail++;
150  continue;
151  }
152  const Matrix<BaseFloat> *online_ivectors = NULL;
153  const Vector<BaseFloat> *ivector = NULL;
154  if (!ivector_rspecifier.empty()) {
155  if (!ivector_reader.HasKey(utt)) {
156  KALDI_WARN << "No iVector available for utterance " << utt;
157  num_fail++;
158  continue;
159  } else {
160  ivector = &ivector_reader.Value(utt);
161  }
162  }
163  if (!online_ivector_rspecifier.empty()) {
164  if (!online_ivector_reader.HasKey(utt)) {
165  KALDI_WARN << "No online iVector available for utterance " << utt;
166  num_fail++;
167  continue;
168  } else {
169  online_ivectors = &online_ivector_reader.Value(utt);
170  }
171  }
172 
173  LatticeFasterDecoder *decoder =
174  new LatticeFasterDecoder(*decode_fst, config);
175 
176  DecodableInterface *nnet_decodable = new
177  DecodableAmNnetSimpleParallel(
178  decodable_opts, trans_model, am_nnet,
179  features, ivector, online_ivectors,
180  online_ivector_period);
181 
182  DecodeUtteranceLatticeFasterClass *task =
183  new DecodeUtteranceLatticeFasterClass(
184  decoder, nnet_decodable, // takes ownership of these two.
185  trans_model, word_syms, utt, decodable_opts.acoustic_scale,
186  determinize, allow_partial, &alignment_writer, &words_writer,
187  &compact_lattice_writer, &lattice_writer,
188  &tot_like, &frame_count, &num_success, &num_fail, NULL);
189 
190  sequencer.Run(task); // takes ownership of "task",
191  // and will delete it when done.
192  }
193  }
194  sequencer.Wait(); // Waits for all tasks to be done.
195  delete decode_fst;
196  } else { // We have different FSTs for different utterances.
197  SequentialTableReader<fst::VectorFstHolder> fst_reader(fst_in_str);
198  RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier);
199  for (; !fst_reader.Done(); fst_reader.Next()) {
200  std::string utt = fst_reader.Key();
201  if (!feature_reader.HasKey(utt)) {
202  KALDI_WARN << "Not decoding utterance " << utt
203  << " because no features available.";
204  num_fail++;
205  continue;
206  }
207  const Matrix<BaseFloat> &features = feature_reader.Value(utt);
208  if (features.NumRows() == 0) {
209  KALDI_WARN << "Zero-length utterance: " << utt;
210  num_fail++;
211  continue;
212  }
213 
214  const Matrix<BaseFloat> *online_ivectors = NULL;
215  const Vector<BaseFloat> *ivector = NULL;
216  if (!ivector_rspecifier.empty()) {
217  if (!ivector_reader.HasKey(utt)) {
218  KALDI_WARN << "No iVector available for utterance " << utt;
219  num_fail++;
220  continue;
221  } else {
222  ivector = &ivector_reader.Value(utt);
223  }
224  }
225  if (!online_ivector_rspecifier.empty()) {
226  if (!online_ivector_reader.HasKey(utt)) {
227  KALDI_WARN << "No online iVector available for utterance " << utt;
228  num_fail++;
229  continue;
230  } else {
231  online_ivectors = &online_ivector_reader.Value(utt);
232  }
233  }
234 
235  // the following constructor takes ownership of the FST pointer so that
236  // it is deleted when 'decoder' is deleted.
237  LatticeFasterDecoder *decoder =
238  new LatticeFasterDecoder(config, fst_reader.Value().Copy());
239 
240  DecodableInterface *nnet_decodable = new
241  DecodableAmNnetSimpleParallel(
242  decodable_opts, trans_model, am_nnet,
243  features, ivector, online_ivectors,
244  online_ivector_period);
245 
246  DecodeUtteranceLatticeFasterClass *task =
247  new DecodeUtteranceLatticeFasterClass(
248  decoder, nnet_decodable, // takes ownership of these two.
249  trans_model, word_syms, utt, decodable_opts.acoustic_scale,
250  determinize, allow_partial, &alignment_writer, &words_writer,
251  &compact_lattice_writer, &lattice_writer,
252  &tot_like, &frame_count, &num_success, &num_fail, NULL);
253 
254  sequencer.Run(task); // takes ownership of "task",
255  // and will delete it when done.
256  }
257  sequencer.Wait(); // Waits for all tasks to be done.
258  }
259 
260  kaldi::int64 input_frame_count =
261  frame_count * decodable_opts.frame_subsampling_factor;
262 
263  double elapsed = timer.Elapsed();
264  KALDI_LOG << "Time taken " << elapsed
265  << "s: real-time factor assuming 100 feature frames/sec is "
266  << (sequencer_config.num_threads * elapsed * 100.0 /
267  input_frame_count);
268  KALDI_LOG << "Done " << num_success << " utterances, failed for "
269  << num_fail;
270  KALDI_LOG << "Overall log-likelihood per frame is "
271  << (tot_like / frame_count) << " over "
272  << frame_count << " frames.";
273 
274  delete word_syms;
275  if (num_success != 0) return 0;
276  else return 1;
277  } catch(const std::exception &e) {
278  std::cerr << e.what();
279  return -1;
280  }
281 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable to be done at test time...
Definition: nnet-utils.cc:2100
bool Open(const std::string &wspecifier)
DecodableInterface provides a link between the (acoustic-modeling and feature-processing) code and th...
Definition: decodable-itf.h:82
Fst< StdArc > * ReadFstKaldiGeneric(std::string rxfilename, bool throw_on_err)
Definition: kaldi-fst-io.cc:45
void Reset()
Definition: timer.h:71
fst::StdArc StdArc
This class is for when you are reading something in random access, but it may actually be stored per-...
Definition: kaldi-table.h:432
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
Definition: nnet-utils.cc:564
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
const Nnet & GetNnet() const
void Read(std::istream &is, bool binary)
RspecifierType ClassifyRspecifier(const std::string &rspecifier, std::string *rxfilename, RspecifierOptions *opts)
Definition: kaldi-table.cc:225
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
Definition: nnet-utils.cc:573
This class basically does the same job as the function DecodeUtteranceLatticeFaster, but in a way that allows us to build a multi-threaded command line program more easily.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
This is the "normal" lattice-generating decoder.
A class representing a vector.
Definition: kaldi-vector.h:406
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
LatticeFasterDecoderTpl< fst::StdFst, decoder::StdToken > LatticeFasterDecoder
#define KALDI_LOG
Definition: kaldi-error.h:153
double Elapsed() const
Returns time in seconds.
Definition: timer.h:74
void Register(OptionsItf *opts)
Definition: kaldi-thread.h:160
Config class for the CollapseModel function.
Definition: nnet-utils.h:240