nnet3-latgen-faster-batch.cc File Reference
Include dependency graph for nnet3-latgen-faster-batch.cc:

Go to the source code of this file.

Namespaces

 kaldi
 This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for mispronunciation detection tasks, the reference:
 

Functions

void HandleOutput (bool determinize, const fst::SymbolTable *word_syms, nnet3::NnetBatchDecoder *decoder, CompactLatticeWriter *clat_writer, LatticeWriter *lat_writer)
 
int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 61 of file nnet3-latgen-faster-batch.cc.

References kaldi::nnet3::CollapseModel(), LatticeFasterDecoderConfig::determinize_lattice, SequentialTableReader< Holder >::Done(), ParseOptions::GetArg(), AmNnetSimple::GetNnet(), kaldi::HandleOutput(), RandomAccessTableReader< Holder >::HasKey(), RandomAccessTableReaderMapped< Holder >::HasKey(), KALDI_ERR, KALDI_WARN, SequentialTableReader< Holder >::Key(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), TableWriter< Holder >::Open(), ParseOptions::PrintUsage(), AmNnetSimple::Priors(), AmNnetSimple::Read(), ParseOptions::Read(), TransitionModel::Read(), fst::ReadFstKaldiGeneric(), LatticeFasterDecoderConfig::Register(), ParseOptions::Register(), NnetBatchComputerOptions::Register(), kaldi::nnet3::SetBatchnormTestMode(), kaldi::nnet3::SetDropoutTestMode(), Input::Stream(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), and RandomAccessTableReaderMapped< Holder >::Value().

61  {
62  // note: making this program work with GPUs is as simple as initializing the
63  // device, but it probably won't make a huge difference in speed for typical
64  // setups.
65  try {
66  using namespace kaldi;
67  using namespace kaldi::nnet3;
68  typedef kaldi::int32 int32;
69  using fst::SymbolTable;
70  using fst::Fst;
71  using fst::StdArc;
72 
73  const char *usage =
74  "Generate lattices using nnet3 neural net model. This version is optimized\n"
75  "for GPU-based inference.\n"
76  "Usage: nnet3-latgen-faster-batch [options] <nnet-in> <fst-in> <features-rspecifier>"
77  " <lattice-wspecifier>\n";
78  ParseOptions po(usage);
79 
80  bool allow_partial = false;
81  LatticeFasterDecoderConfig decoder_opts;
82  NnetBatchComputerOptions compute_opts;
83  std::string use_gpu = "yes";
84 
85  std::string word_syms_filename;
86  std::string ivector_rspecifier,
87  online_ivector_rspecifier,
88  utt2spk_rspecifier;
89  int32 online_ivector_period = 0, num_threads = 1;
90  decoder_opts.Register(&po);
91  compute_opts.Register(&po);
92  po.Register("word-symbol-table", &word_syms_filename,
93  "Symbol table for words [for debug output]");
94  po.Register("allow-partial", &allow_partial,
95  "If true, produce output even if end state was not reached.");
96  po.Register("ivectors", &ivector_rspecifier, "Rspecifier for "
97  "iVectors as vectors (i.e. not estimated online); per utterance "
98  "by default, or per speaker if you provide the --utt2spk option.");
99  po.Register("online-ivectors", &online_ivector_rspecifier, "Rspecifier for "
100  "iVectors estimated online, as matrices. If you supply this,"
101  " you must set the --online-ivector-period option.");
102  po.Register("online-ivector-period", &online_ivector_period, "Number of frames "
103  "between iVectors in matrices supplied to the --online-ivectors "
104  "option");
105  po.Register("num-threads", &num_threads, "Number of decoder (i.e. "
106  "graph-search) threads. The number of model-evaluation threads "
107  "is always 1; this is optimized for use with the GPU.");
108  po.Register("use-gpu", &use_gpu,
109  "yes|no|optional|wait, only has effect if compiled with CUDA");
110 
111 #if HAVE_CUDA==1
112  CuDevice::RegisterDeviceOptions(&po);
113 #endif
114 
115  po.Read(argc, argv);
116 
117  if (po.NumArgs() != 4) {
118  po.PrintUsage();
119  exit(1);
120  }
121 
122 #if HAVE_CUDA==1
123  CuDevice::Instantiate().AllowMultithreading();
124  CuDevice::Instantiate().SelectGpuId(use_gpu);
125 #endif
126 
127  std::string model_in_rxfilename = po.GetArg(1),
128  fst_in_rxfilename = po.GetArg(2),
129  feature_rspecifier = po.GetArg(3),
130  lattice_wspecifier = po.GetArg(4);
131 
132  TransitionModel trans_model;
133  AmNnetSimple am_nnet;
134  {
135  bool binary;
136  Input ki(model_in_rxfilename, &binary);
137  trans_model.Read(ki.Stream(), binary);
138  am_nnet.Read(ki.Stream(), binary);
139  SetBatchnormTestMode(true, &(am_nnet.GetNnet()));
140  SetDropoutTestMode(true, &(am_nnet.GetNnet()));
141  CollapseModel(CollapseModelConfig(), &(am_nnet.GetNnet()));
142  }
143 
144  bool determinize = decoder_opts.determinize_lattice;
145  CompactLatticeWriter compact_lattice_writer;
146  LatticeWriter lattice_writer;
147  if (! (determinize ? compact_lattice_writer.Open(lattice_wspecifier)
148  : lattice_writer.Open(lattice_wspecifier)))
149  KALDI_ERR << "Could not open table for writing lattices: "
150  << lattice_wspecifier;
151 
152  RandomAccessBaseFloatMatrixReader online_ivector_reader(
153  online_ivector_rspecifier);
154  RandomAccessBaseFloatVectorReaderMapped ivector_reader(
155  ivector_rspecifier, utt2spk_rspecifier);
156 
157  fst::SymbolTable *word_syms = NULL;
158  if (word_syms_filename != "")
159  if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename)))
160  KALDI_ERR << "Could not read symbol table from file "
161  << word_syms_filename;
162 
163 
164  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
165 
166  Fst<StdArc> *decode_fst = fst::ReadFstKaldiGeneric(fst_in_rxfilename);
167 
168  int32 num_success;
169  {
170  NnetBatchComputer computer(compute_opts, am_nnet.GetNnet(),
171  am_nnet.Priors());
172  NnetBatchDecoder decoder(*decode_fst, decoder_opts,
173  trans_model, word_syms, allow_partial,
174  num_threads, &computer);
175 
176  for (; !feature_reader.Done(); feature_reader.Next()) {
177  std::string utt = feature_reader.Key();
178  const Matrix<BaseFloat> &features (feature_reader.Value());
179 
180  if (features.NumRows() == 0) {
181  KALDI_WARN << "Zero-length utterance: " << utt;
182  decoder.UtteranceFailed();
183  continue;
184  }
185  const Matrix<BaseFloat> *online_ivectors = NULL;
186  const Vector<BaseFloat> *ivector = NULL;
187  if (!ivector_rspecifier.empty()) {
188  if (!ivector_reader.HasKey(utt)) {
189  KALDI_WARN << "No iVector available for utterance " << utt;
190  decoder.UtteranceFailed();
191  continue;
192  } else {
193  ivector = &ivector_reader.Value(utt);
194  }
195  }
196  if (!online_ivector_rspecifier.empty()) {
197  if (!online_ivector_reader.HasKey(utt)) {
198  KALDI_WARN << "No online iVector available for utterance " << utt;
199  decoder.UtteranceFailed();
200  continue;
201  } else {
202  online_ivectors = &online_ivector_reader.Value(utt);
203  }
204  }
205 
206  decoder.AcceptInput(utt, features, ivector, online_ivectors,
207  online_ivector_period);
208 
209  HandleOutput(decoder_opts.determinize_lattice, word_syms, &decoder,
210  &compact_lattice_writer, &lattice_writer);
211  }
212  num_success = decoder.Finished();
213  HandleOutput(decoder_opts.determinize_lattice, word_syms, &decoder,
214  &compact_lattice_writer, &lattice_writer);
215 
216  // At this point the decoder and batch-computer objects will print
217  // diagnostics from their destructors (they are going out of scope).
218  }
219  delete decode_fst;
220  delete word_syms;
221 
222 #if HAVE_CUDA==1
223  CuDevice::Instantiate().PrintProfile();
224 #endif
225 
226  return (num_success != 0 ? 0 : 1);
227  } catch(const std::exception &e) {
228  std::cerr << e.what();
229  return -1;
230  }
231 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
Definition: chain.dox:20
Decoder object that uses multiple CPU threads for the graph search, plus a GPU for the neural net inf...
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable to be done at test time...
Definition: nnet-utils.cc:2100
bool Open(const std::string &wspecifier)
Fst< StdArc > * ReadFstKaldiGeneric(std::string rxfilename, bool throw_on_err)
Definition: kaldi-fst-io.cc:45
fst::StdArc StdArc
This class is for when you are reading something in random access, but it may actually be stored per-...
Definition: kaldi-table.h:432
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
Definition: nnet-utils.cc:564
void HandleOutput(bool determinize, const fst::SymbolTable *word_syms, nnet3::NnetBatchDecoder *decoder, CompactLatticeWriter *clat_writer, LatticeWriter *lat_writer)
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
const Nnet & GetNnet() const
void Read(std::istream &is, bool binary)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
Definition: nnet-utils.cc:573
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
const VectorBase< BaseFloat > & Priors() const
A class representing a vector.
Definition: kaldi-vector.h:406
This class does neural net inference in a way that is optimized for GPU use: it combines chunks of mu...
Config class for the CollapseModel function.
Definition: nnet-utils.h:240