nnet3-latgen-grammar.cc File Reference
Include dependency graph for nnet3-latgen-grammar.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 33 of file nnet3-latgen-grammar.cc.

References NnetSimpleComputationOptions::acoustic_scale, kaldi::nnet3::CollapseModel(), kaldi::DecodeUtteranceLatticeFaster(), LatticeFasterDecoderConfig::determinize_lattice, Timer::Elapsed(), NnetSimpleComputationOptions::frame_subsampling_factor, ParseOptions::GetArg(), AmNnetSimple::GetNnet(), ParseOptions::GetOptArg(), RandomAccessTableReader< Holder >::HasKey(), RandomAccessTableReaderMapped< Holder >::HasKey(), KALDI_ERR, KALDI_LOG, KALDI_WARN, ParseOptions::NumArgs(), DecodableAmNnetSimple::NumFramesReady(), TableWriter< Holder >::Open(), NnetSimpleComputationOptions::optimize_config, ParseOptions::PrintUsage(), AmNnetSimple::Read(), ParseOptions::Read(), TransitionModel::Read(), kaldi::ReadKaldiObject(), LatticeFasterDecoderConfig::Register(), NnetSimpleComputationOptions::Register(), ParseOptions::Register(), Timer::Reset(), kaldi::nnet3::SetBatchnormTestMode(), kaldi::nnet3::SetDropoutTestMode(), Input::Stream(), RandomAccessTableReader< Holder >::Value(), and RandomAccessTableReaderMapped< Holder >::Value().

33  {
34  // note: making this program work with GPUs is as simple as initializing the
35  // device, but it probably won't make a huge difference in speed for typical
36  // setups.
37  try {
38  using namespace kaldi;
39  using namespace kaldi::nnet3;
40  typedef kaldi::int32 int32;
41  using fst::SymbolTable;
42  using fst::Fst;
43  using fst::StdArc;
44 
45  const char *usage =
46  "Generate lattices using nnet3 neural net model, and GrammarFst-based graph\n"
47  "see kaldi-asr.org/doc/grammar.html for more context.\n"
48  "\n"
49  "Usage: nnet3-latgen-grammar [options] <nnet-in> <grammar-fst-in> <features-rspecifier>"
50  " <lattice-wspecifier> [ <words-wspecifier> [<alignments-wspecifier>] ]\n";
51 
52  ParseOptions po(usage);
53  Timer timer;
54  bool allow_partial = false;
55  LatticeFasterDecoderConfig config;
56  NnetSimpleComputationOptions decodable_opts;
57 
58  std::string word_syms_filename;
59  std::string ivector_rspecifier,
60  online_ivector_rspecifier,
61  utt2spk_rspecifier;
62  int32 online_ivector_period = 0;
63  config.Register(&po);
64  decodable_opts.Register(&po);
65  po.Register("word-symbol-table", &word_syms_filename,
66  "Symbol table for words [for debug output]");
67  po.Register("allow-partial", &allow_partial,
68  "If true, produce output even if end state was not reached.");
69  po.Register("ivectors", &ivector_rspecifier, "Rspecifier for "
70  "iVectors as vectors (i.e. not estimated online); per utterance "
71  "by default, or per speaker if you provide the --utt2spk option.");
72  po.Register("utt2spk", &utt2spk_rspecifier, "Rspecifier for "
73  "utt2spk option used to get ivectors per speaker");
74  po.Register("online-ivectors", &online_ivector_rspecifier, "Rspecifier for "
75  "iVectors estimated online, as matrices. If you supply this,"
76  " you must set the --online-ivector-period option.");
77  po.Register("online-ivector-period", &online_ivector_period, "Number of frames "
78  "between iVectors in matrices supplied to the --online-ivectors "
79  "option");
80 
81  po.Read(argc, argv);
82 
83  if (po.NumArgs() < 4 || po.NumArgs() > 6) {
84  po.PrintUsage();
85  exit(1);
86  }
87 
88  std::string model_rxfilename = po.GetArg(1),
89  grammar_fst_rxfilename = po.GetArg(2),
90  feature_rspecifier = po.GetArg(3),
91  lattice_wspecifier = po.GetArg(4),
92  words_wspecifier = po.GetOptArg(5),
93  alignment_wspecifier = po.GetOptArg(6);
94 
95  TransitionModel trans_model;
96  AmNnetSimple am_nnet;
97  {
98  bool binary;
99  Input ki(model_rxfilename, &binary);
100  trans_model.Read(ki.Stream(), binary);
101  am_nnet.Read(ki.Stream(), binary);
102  SetBatchnormTestMode(true, &(am_nnet.GetNnet()));
103  SetDropoutTestMode(true, &(am_nnet.GetNnet()));
104  CollapseModel(CollapseModelConfig(), &(am_nnet.GetNnet()));
105  }
106 
107  bool determinize = config.determinize_lattice;
108  CompactLatticeWriter compact_lattice_writer;
109  LatticeWriter lattice_writer;
110  if (! (determinize ? compact_lattice_writer.Open(lattice_wspecifier)
111  : lattice_writer.Open(lattice_wspecifier)))
112  KALDI_ERR << "Could not open table for writing lattices: "
113  << lattice_wspecifier;
114 
115  RandomAccessBaseFloatMatrixReader online_ivector_reader(
116  online_ivector_rspecifier);
117  RandomAccessBaseFloatVectorReaderMapped ivector_reader(
118  ivector_rspecifier, utt2spk_rspecifier);
119 
120  Int32VectorWriter words_writer(words_wspecifier);
121  Int32VectorWriter alignment_writer(alignment_wspecifier);
122 
123  fst::SymbolTable *word_syms = NULL;
124  if (word_syms_filename != "")
125  if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename)))
126  KALDI_ERR << "Could not read symbol table from file "
127  << word_syms_filename;
128 
129  double tot_like = 0.0;
130  kaldi::int64 frame_count = 0;
131  int num_success = 0, num_fail = 0;
132  // this compiler object allows caching of computations across
133  // different utterances.
134  CachingOptimizingCompiler compiler(am_nnet.GetNnet(),
135  decodable_opts.optimize_config);
136 
137  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
138 
138 
139  fst::GrammarFst fst;
140  ReadKaldiObject(grammar_fst_rxfilename, &fst);
141  timer.Reset();
142 
143  {
144  LatticeFasterDecoderTpl<fst::GrammarFst> decoder(fst, config);
145 
146  for (; !feature_reader.Done(); feature_reader.Next()) {
147  std::string utt = feature_reader.Key();
148  const Matrix<BaseFloat> &features (feature_reader.Value());
149  if (features.NumRows() == 0) {
150  KALDI_WARN << "Zero-length utterance: " << utt;
151  num_fail++;
152  continue;
153  }
154  const Matrix<BaseFloat> *online_ivectors = NULL;
155  const Vector<BaseFloat> *ivector = NULL;
156  if (!ivector_rspecifier.empty()) {
157  if (!ivector_reader.HasKey(utt)) {
158  KALDI_WARN << "No iVector available for utterance " << utt;
159  num_fail++;
160  continue;
161  } else {
162  ivector = &ivector_reader.Value(utt);
163  }
164  }
165  if (!online_ivector_rspecifier.empty()) {
166  if (!online_ivector_reader.HasKey(utt)) {
167  KALDI_WARN << "No online iVector available for utterance " << utt;
168  num_fail++;
169  continue;
170  } else {
171  online_ivectors = &online_ivector_reader.Value(utt);
172  }
173  }
174 
175  DecodableAmNnetSimple nnet_decodable(
176  decodable_opts, trans_model, am_nnet,
177  features, ivector, online_ivectors,
178  online_ivector_period, &compiler);
179 
180  double like;
181  if (DecodeUtteranceLatticeFaster(
182  decoder, nnet_decodable, trans_model, word_syms, utt,
183  decodable_opts.acoustic_scale, determinize, allow_partial,
184  &alignment_writer, &words_writer, &compact_lattice_writer,
185  &lattice_writer,
186  &like)) {
187  tot_like += like;
188  frame_count += nnet_decodable.NumFramesReady();
189  num_success++;
190  } else num_fail++;
191  }
192  }
193 
194  kaldi::int64 input_frame_count =
195  frame_count * decodable_opts.frame_subsampling_factor;
196 
197  double elapsed = timer.Elapsed();
198  KALDI_LOG << "Time taken "<< elapsed
199  << "s: real-time factor assuming 100 frames/sec is "
200  << (elapsed * 100.0 / input_frame_count);
201  KALDI_LOG << "Done " << num_success << " utterances, failed for "
202  << num_fail;
203  KALDI_LOG << "Overall log-likelihood per frame is "
204  << (tot_like / frame_count) << " over "
205  << frame_count << " frames.";
206 
207  delete word_syms;
208  if (num_success != 0) return 0;
209  else return 1;
210  } catch(const std::exception &e) {
211  std::cerr << e.what();
212  return -1;
213  }
214 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable to be done at test time...
Definition: nnet-utils.cc:2100
bool Open(const std::string &wspecifier)
void Reset()
Definition: timer.h:71
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
fst::StdArc StdArc
This class is for when you are reading something in random access, but it may actually be stored per-...
Definition: kaldi-table.h:432
This class enables you to do the compilation and optimization in one call, and also ensures that if t...
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
Definition: nnet-utils.cc:564
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
bool DecodeUtteranceLatticeFaster(LatticeFasterDecoderTpl< FST > &decoder, DecodableInterface &decodable, const TransitionModel &trans_model, const fst::SymbolTable *word_syms, std::string utt, double acoustic_scale, bool determinize, bool allow_partial, Int32VectorWriter *alignment_writer, Int32VectorWriter *words_writer, CompactLatticeWriter *compact_lattice_writer, LatticeWriter *lattice_writer, double *like_ptr)
This function DecodeUtteranceLatticeFaster is used in several decoders, and we have moved it here...
const Nnet & GetNnet() const
void Read(std::istream &is, bool binary)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
Definition: nnet-utils.cc:573
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
#define KALDI_ERR
Definition: kaldi-error.h:147
GrammarFst is an FST that is 'stitched together' from multiple FSTs, that can recursively incorporate...
Definition: grammar-fst.h:96
#define KALDI_WARN
Definition: kaldi-error.h:150
This is the "normal" lattice-generating decoder.
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_LOG
Definition: kaldi-error.h:153
double Elapsed() const
Returns time in seconds.
Definition: timer.h:74
Config class for the CollapseModel function.
Definition: nnet-utils.h:240