nnet3-latgen-grammar.cc
Go to the documentation of this file.
1 // nnet3bin/nnet3-latgen-grammar.cc
2 
3 // Copyright 2018 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "tree/context-dep.h"
#include "hmm/transition-model.h"
#include "fstext/fstext-lib.h"
#include "decoder/decoder-wrappers.h"
#include "nnet3/nnet-am-decodable-simple.h"
#include "nnet3/nnet-utils.h"
#include "decoder/grammar-fst.h"
#include "base/timer.h"
31 
32 
33 int main(int argc, char *argv[]) {
34  // note: making this program work with GPUs is as simple as initializing the
35  // device, but it probably won't make a huge difference in speed for typical
36  // setups.
37  try {
38  using namespace kaldi;
39  using namespace kaldi::nnet3;
40  typedef kaldi::int32 int32;
41  using fst::SymbolTable;
42  using fst::Fst;
43  using fst::StdArc;
44 
45  const char *usage =
46  "Generate lattices using nnet3 neural net model, and GrammarFst-based graph\n"
47  "see kaldi-asr.org/doc/grammar.html for more context.\n"
48  "\n"
49  "Usage: nnet3-latgen-grammar [options] <nnet-in> <grammar-fst-in> <features-rspecifier>"
50  " <lattice-wspecifier> [ <words-wspecifier> [<alignments-wspecifier>] ]\n";
51 
52  ParseOptions po(usage);
53  Timer timer;
54  bool allow_partial = false;
56  NnetSimpleComputationOptions decodable_opts;
57 
58  std::string word_syms_filename;
59  std::string ivector_rspecifier,
60  online_ivector_rspecifier,
61  utt2spk_rspecifier;
62  int32 online_ivector_period = 0;
63  config.Register(&po);
64  decodable_opts.Register(&po);
65  po.Register("word-symbol-table", &word_syms_filename,
66  "Symbol table for words [for debug output]");
67  po.Register("allow-partial", &allow_partial,
68  "If true, produce output even if end state was not reached.");
69  po.Register("ivectors", &ivector_rspecifier, "Rspecifier for "
70  "iVectors as vectors (i.e. not estimated online); per utterance "
71  "by default, or per speaker if you provide the --utt2spk option.");
72  po.Register("utt2spk", &utt2spk_rspecifier, "Rspecifier for "
73  "utt2spk option used to get ivectors per speaker");
74  po.Register("online-ivectors", &online_ivector_rspecifier, "Rspecifier for "
75  "iVectors estimated online, as matrices. If you supply this,"
76  " you must set the --online-ivector-period option.");
77  po.Register("online-ivector-period", &online_ivector_period, "Number of frames "
78  "between iVectors in matrices supplied to the --online-ivectors "
79  "option");
80 
81  po.Read(argc, argv);
82 
83  if (po.NumArgs() < 4 || po.NumArgs() > 6) {
84  po.PrintUsage();
85  exit(1);
86  }
87 
88  std::string model_rxfilename = po.GetArg(1),
89  grammar_fst_rxfilename = po.GetArg(2),
90  feature_rspecifier = po.GetArg(3),
91  lattice_wspecifier = po.GetArg(4),
92  words_wspecifier = po.GetOptArg(5),
93  alignment_wspecifier = po.GetOptArg(6);
94 
95  TransitionModel trans_model;
96  AmNnetSimple am_nnet;
97  {
98  bool binary;
99  Input ki(model_rxfilename, &binary);
100  trans_model.Read(ki.Stream(), binary);
101  am_nnet.Read(ki.Stream(), binary);
102  SetBatchnormTestMode(true, &(am_nnet.GetNnet()));
103  SetDropoutTestMode(true, &(am_nnet.GetNnet()));
104  CollapseModel(CollapseModelConfig(), &(am_nnet.GetNnet()));
105  }
106 
107  bool determinize = config.determinize_lattice;
108  CompactLatticeWriter compact_lattice_writer;
109  LatticeWriter lattice_writer;
110  if (! (determinize ? compact_lattice_writer.Open(lattice_wspecifier)
111  : lattice_writer.Open(lattice_wspecifier)))
112  KALDI_ERR << "Could not open table for writing lattices: "
113  << lattice_wspecifier;
114 
115  RandomAccessBaseFloatMatrixReader online_ivector_reader(
116  online_ivector_rspecifier);
118  ivector_rspecifier, utt2spk_rspecifier);
119 
120  Int32VectorWriter words_writer(words_wspecifier);
121  Int32VectorWriter alignment_writer(alignment_wspecifier);
122 
123  fst::SymbolTable *word_syms = NULL;
124  if (word_syms_filename != "")
125  if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename)))
126  KALDI_ERR << "Could not read symbol table from file "
127  << word_syms_filename;
128 
129  double tot_like = 0.0;
130  kaldi::int64 frame_count = 0;
131  int num_success = 0, num_fail = 0;
132  // this compiler object allows caching of computations across
133  // different utterances.
134  CachingOptimizingCompiler compiler(am_nnet.GetNnet(),
135  decodable_opts.optimize_config);
136 
137  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
138 
140  ReadKaldiObject(grammar_fst_rxfilename, &fst);
141  timer.Reset();
142 
143  {
144  LatticeFasterDecoderTpl<fst::GrammarFst> decoder(fst, config);
145 
146  for (; !feature_reader.Done(); feature_reader.Next()) {
147  std::string utt = feature_reader.Key();
148  const Matrix<BaseFloat> &features (feature_reader.Value());
149  if (features.NumRows() == 0) {
150  KALDI_WARN << "Zero-length utterance: " << utt;
151  num_fail++;
152  continue;
153  }
154  const Matrix<BaseFloat> *online_ivectors = NULL;
155  const Vector<BaseFloat> *ivector = NULL;
156  if (!ivector_rspecifier.empty()) {
157  if (!ivector_reader.HasKey(utt)) {
158  KALDI_WARN << "No iVector available for utterance " << utt;
159  num_fail++;
160  continue;
161  } else {
162  ivector = &ivector_reader.Value(utt);
163  }
164  }
165  if (!online_ivector_rspecifier.empty()) {
166  if (!online_ivector_reader.HasKey(utt)) {
167  KALDI_WARN << "No online iVector available for utterance " << utt;
168  num_fail++;
169  continue;
170  } else {
171  online_ivectors = &online_ivector_reader.Value(utt);
172  }
173  }
174 
175  DecodableAmNnetSimple nnet_decodable(
176  decodable_opts, trans_model, am_nnet,
177  features, ivector, online_ivectors,
178  online_ivector_period, &compiler);
179 
180  double like;
182  decoder, nnet_decodable, trans_model, word_syms, utt,
183  decodable_opts.acoustic_scale, determinize, allow_partial,
184  &alignment_writer, &words_writer, &compact_lattice_writer,
185  &lattice_writer,
186  &like)) {
187  tot_like += like;
188  frame_count += nnet_decodable.NumFramesReady();
189  num_success++;
190  } else num_fail++;
191  }
192  }
193 
194  kaldi::int64 input_frame_count =
195  frame_count * decodable_opts.frame_subsampling_factor;
196 
197  double elapsed = timer.Elapsed();
198  KALDI_LOG << "Time taken "<< elapsed
199  << "s: real-time factor assuming 100 frames/sec is "
200  << (elapsed * 100.0 / input_frame_count);
201  KALDI_LOG << "Done " << num_success << " utterances, failed for "
202  << num_fail;
203  KALDI_LOG << "Overall log-likelihood per frame is "
204  << (tot_like / frame_count) << " over "
205  << frame_count << " frames.";
206 
207  delete word_syms;
208  if (num_success != 0) return 0;
209  else return 1;
210  } catch(const std::exception &e) {
211  std::cerr << e.what();
212  return -1;
213  }
214 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable to be done at test time...
Definition: nnet-utils.cc:2100
bool Open(const std::string &wspecifier)
void Reset()
Definition: timer.h:71
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
fst::StdArc StdArc
This class is for when you are reading something in random access, but it may actually be stored per-...
Definition: kaldi-table.h:432
This class enables you to do the compilation and optimization in one call, and also ensures that if t...
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
Definition: nnet-utils.cc:564
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
bool DecodeUtteranceLatticeFaster(LatticeFasterDecoderTpl< FST > &decoder, DecodableInterface &decodable, const TransitionModel &trans_model, const fst::SymbolTable *word_syms, std::string utt, double acoustic_scale, bool determinize, bool allow_partial, Int32VectorWriter *alignment_writer, Int32VectorWriter *words_writer, CompactLatticeWriter *compact_lattice_writer, LatticeWriter *lattice_writer, double *like_ptr)
This function DecodeUtteranceLatticeFaster is used in several decoders, and we have moved it here...
const Nnet & GetNnet() const
void Read(std::istream &is, bool binary)
void Register(const std::string &name, bool *ptr, const std::string &doc)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
This file contains some miscellaneous functions dealing with class Nnet.
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
Definition: nnet-utils.cc:573
std::istream & Stream()
Definition: kaldi-io.cc:826
int main(int argc, char *argv[])
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const T & Value(const std::string &key)
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
GrammarFst is an FST that is 'stitched together' from multiple FSTs, that can recursively incorporate...
Definition: grammar-fst.h:96
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
bool HasKey(const std::string &key)
This is the "normal" lattice-generating decoder.
int NumArgs() const
Number of positional parameters (c.f. argc-1).
A class representing a vector.
Definition: kaldi-vector.h:406
virtual int32 NumFramesReady() const
The call NumFramesReady() will return the number of frames currently available for this decodable obj...
const T & Value(const std::string &key)
#define KALDI_LOG
Definition: kaldi-error.h:153
double Elapsed() const
Returns time in seconds.
Definition: timer.h:74
std::string GetOptArg(int param) const
Config class for the CollapseModel function.
Definition: nnet-utils.h:240