66 using namespace kaldi;
69 using fst::SymbolTable;
74 "Generate lattices using nnet3 neural net model. This version is optimized\n" 75 "for GPU-based inference.\n" 76 "Usage: nnet3-latgen-faster-batch [options] <nnet-in> <fst-in> <features-rspecifier>" 77 " <lattice-wspecifier>\n";
80 bool allow_partial =
false;
83 std::string use_gpu =
"yes";
85 std::string word_syms_filename;
86 std::string ivector_rspecifier,
87 online_ivector_rspecifier,
89 int32 online_ivector_period = 0, num_threads = 1;
92 po.Register(
"word-symbol-table", &word_syms_filename,
93 "Symbol table for words [for debug output]");
94 po.Register(
"allow-partial", &allow_partial,
95 "If true, produce output even if end state was not reached.");
96 po.Register(
"ivectors", &ivector_rspecifier,
"Rspecifier for " 97 "iVectors as vectors (i.e. not estimated online); per utterance " 98 "by default, or per speaker if you provide the --utt2spk option.");
99 po.Register(
"online-ivectors", &online_ivector_rspecifier,
"Rspecifier for " 100 "iVectors estimated online, as matrices. If you supply this," 101 " you must set the --online-ivector-period option.");
102 po.Register(
"online-ivector-period", &online_ivector_period,
"Number of frames " 103 "between iVectors in matrices supplied to the --online-ivectors " 105 po.Register(
"num-threads", &num_threads,
"Number of decoder (i.e. " 106 "graph-search) threads. The number of model-evaluation threads " 107 "is always 1; this is optimized for use with the GPU.");
108 po.Register(
"use-gpu", &use_gpu,
109 "yes|no|optional|wait, only has effect if compiled with CUDA");
112 CuDevice::RegisterDeviceOptions(&po);
117 if (po.NumArgs() != 4) {
123 CuDevice::Instantiate().AllowMultithreading();
124 CuDevice::Instantiate().SelectGpuId(use_gpu);
127 std::string model_in_rxfilename = po.GetArg(1),
128 fst_in_rxfilename = po.GetArg(2),
129 feature_rspecifier = po.GetArg(3),
130 lattice_wspecifier = po.GetArg(4);
136 Input ki(model_in_rxfilename, &binary);
137 trans_model.
Read(ki.Stream(), binary);
138 am_nnet.
Read(ki.Stream(), binary);
147 if (! (determinize ? compact_lattice_writer.
Open(lattice_wspecifier)
148 : lattice_writer.
Open(lattice_wspecifier)))
149 KALDI_ERR <<
"Could not open table for writing lattices: " 150 << lattice_wspecifier;
153 online_ivector_rspecifier);
155 ivector_rspecifier, utt2spk_rspecifier);
157 fst::SymbolTable *word_syms = NULL;
158 if (word_syms_filename !=
"")
159 if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename)))
160 KALDI_ERR <<
"Could not read symbol table from file " 161 << word_syms_filename;
173 trans_model, word_syms, allow_partial,
174 num_threads, &computer);
176 for (; !feature_reader.Done(); feature_reader.Next()) {
177 std::string utt = feature_reader.Key();
180 if (features.NumRows() == 0) {
181 KALDI_WARN <<
"Zero-length utterance: " << utt;
182 decoder.UtteranceFailed();
187 if (!ivector_rspecifier.empty()) {
188 if (!ivector_reader.HasKey(utt)) {
189 KALDI_WARN <<
"No iVector available for utterance " << utt;
190 decoder.UtteranceFailed();
193 ivector = &ivector_reader.Value(utt);
196 if (!online_ivector_rspecifier.empty()) {
197 if (!online_ivector_reader.HasKey(utt)) {
198 KALDI_WARN <<
"No online iVector available for utterance " << utt;
199 decoder.UtteranceFailed();
202 online_ivectors = &online_ivector_reader.Value(utt);
206 decoder.AcceptInput(utt, features, ivector, online_ivectors,
207 online_ivector_period);
210 &compact_lattice_writer, &lattice_writer);
212 num_success = decoder.Finished();
214 &compact_lattice_writer, &lattice_writer);
223 CuDevice::Instantiate().PrintProfile();
226 return (num_success != 0 ? 0 : 1);
227 }
catch(
const std::exception &e) {
228 std::cerr << e.what();
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
Decoder object that uses multiple CPU threads for the graph search, plus a GPU for the neural net inf...
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable for test time...
bool Open(const std::string &wspecifier)
Fst< StdArc > * ReadFstKaldiGeneric(std::string rxfilename, bool throw_on_err)
This class is for when you are reading something in random access, but it may actually be stored per-...
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
void HandleOutput(bool determinize, const fst::SymbolTable *word_syms, nnet3::NnetBatchDecoder *decoder, CompactLatticeWriter *clat_writer, LatticeWriter *lat_writer)
A templated class for writing objects to an archive or script file; see The Table concept...
const Nnet & GetNnet() const
void Read(std::istream &is, bool binary)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
void Register(OptionsItf *po)
const VectorBase< BaseFloat > & Priors() const
A class representing a vector.
This class does neural net inference in a way that is optimized for GPU use: it combines chunks of mu...
void Register(OptionsItf *opts)
Config class for the CollapseModel function.