36 const fst::SymbolTable *word_syms,
41 std::string output_utterance_id, sentence;
44 while (decoder->
GetOutput(&output_utterance_id, &clat, &sentence)) {
45 if (word_syms != NULL)
46 std::cerr << output_utterance_id <<
' ' << sentence <<
'\n';
47 clat_writer->
Write(output_utterance_id, clat);
51 while (decoder->
GetOutput(&output_utterance_id, &lat, &sentence)) {
52 if (word_syms != NULL)
53 std::cerr << output_utterance_id <<
' ' << sentence <<
'\n';
54 lat_writer->
Write(output_utterance_id, lat);
61 int main(
int argc,
char *argv[]) {
66 using namespace kaldi;
69 using fst::SymbolTable;
74 "Generate lattices using nnet3 neural net model. This version is optimized\n" 75 "for GPU-based inference.\n" 76 "Usage: nnet3-latgen-faster-batch [options] <nnet-in> <fst-in> <features-rspecifier>" 77 " <lattice-wspecifier>\n";
80 bool allow_partial =
false;
83 std::string use_gpu =
"yes";
85 std::string word_syms_filename;
86 std::string ivector_rspecifier,
87 online_ivector_rspecifier,
89 int32 online_ivector_period = 0, num_threads = 1;
92 po.
Register(
"word-symbol-table", &word_syms_filename,
93 "Symbol table for words [for debug output]");
94 po.
Register(
"allow-partial", &allow_partial,
95 "If true, produce output even if end state was not reached.");
96 po.
Register(
"ivectors", &ivector_rspecifier,
"Rspecifier for " 97 "iVectors as vectors (i.e. not estimated online); per utterance " 98 "by default, or per speaker if you provide the --utt2spk option.");
99 po.
Register(
"online-ivectors", &online_ivector_rspecifier,
"Rspecifier for " 100 "iVectors estimated online, as matrices. If you supply this," 101 " you must set the --online-ivector-period option.");
102 po.
Register(
"online-ivector-period", &online_ivector_period,
"Number of frames " 103 "between iVectors in matrices supplied to the --online-ivectors " 105 po.
Register(
"num-threads", &num_threads,
"Number of decoder (i.e. " 106 "graph-search) threads. The number of model-evaluation threads " 107 "is always 1; this is optimized for use with the GPU.");
109 "yes|no|optional|wait, only has effect if compiled with CUDA");
112 CuDevice::RegisterDeviceOptions(&po);
123 CuDevice::Instantiate().AllowMultithreading();
124 CuDevice::Instantiate().SelectGpuId(use_gpu);
127 std::string model_in_rxfilename = po.
GetArg(1),
128 fst_in_rxfilename = po.
GetArg(2),
129 feature_rspecifier = po.
GetArg(3),
130 lattice_wspecifier = po.
GetArg(4);
136 Input ki(model_in_rxfilename, &binary);
147 if (! (determinize ? compact_lattice_writer.
Open(lattice_wspecifier)
148 : lattice_writer.
Open(lattice_wspecifier)))
149 KALDI_ERR <<
"Could not open table for writing lattices: " 150 << lattice_wspecifier;
153 online_ivector_rspecifier);
155 ivector_rspecifier, utt2spk_rspecifier);
157 fst::SymbolTable *word_syms = NULL;
158 if (word_syms_filename !=
"")
159 if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename)))
160 KALDI_ERR <<
"Could not read symbol table from file " 161 << word_syms_filename;
173 trans_model, word_syms, allow_partial,
174 num_threads, &computer);
176 for (; !feature_reader.
Done(); feature_reader.
Next()) {
177 std::string utt = feature_reader.
Key();
180 if (features.NumRows() == 0) {
181 KALDI_WARN <<
"Zero-length utterance: " << utt;
182 decoder.UtteranceFailed();
187 if (!ivector_rspecifier.empty()) {
188 if (!ivector_reader.
HasKey(utt)) {
189 KALDI_WARN <<
"No iVector available for utterance " << utt;
190 decoder.UtteranceFailed();
193 ivector = &ivector_reader.
Value(utt);
196 if (!online_ivector_rspecifier.empty()) {
197 if (!online_ivector_reader.
HasKey(utt)) {
198 KALDI_WARN <<
"No online iVector available for utterance " << utt;
199 decoder.UtteranceFailed();
202 online_ivectors = &online_ivector_reader.
Value(utt);
206 decoder.AcceptInput(utt, features, ivector, online_ivectors,
207 online_ivector_period);
210 &compact_lattice_writer, &lattice_writer);
212 num_success = decoder.Finished();
214 &compact_lattice_writer, &lattice_writer);
223 CuDevice::Instantiate().PrintProfile();
226 return (num_success != 0 ? 0 : 1);
227 }
catch(
const std::exception &e) {
228 std::cerr << e.what();
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
Decoder object that uses multiple CPU threads for the graph search, plus a GPU for the neural net inf...
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable to be done at test time...
bool Open(const std::string &wspecifier)
Fst< StdArc > * ReadFstKaldiGeneric(std::string rxfilename, bool throw_on_err)
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
This class is for when you are reading something in random access, but it may actually be stored per-...
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
void HandleOutput(bool determinize, const fst::SymbolTable *word_syms, nnet3::NnetBatchDecoder *decoder, CompactLatticeWriter *clat_writer, LatticeWriter *lat_writer)
A templated class for writing objects to an archive or script file; see The Table concept...
const Nnet & GetNnet() const
void Write(const std::string &key, const T &value) const
void Read(std::istream &is, bool binary)
void Register(const std::string &name, bool *ptr, const std::string &doc)
This file contains some miscellaneous functions dealing with class Nnet.
Allows random access to a collection of objects in an archive or script file; see The Table concept...
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
const T & Value(const std::string &key)
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
fst::VectorFst< LatticeArc > Lattice
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
bool HasKey(const std::string &key)
void Register(OptionsItf *po)
fst::VectorFst< CompactLatticeArc > CompactLattice
const VectorBase< BaseFloat > & Priors() const
int NumArgs() const
Number of positional parameters (c.f. argc-1).
A class representing a vector.
const T & Value(const std::string &key)
int main(int argc, char *argv[])
This class does neural net inference in a way that is optimized for GPU use: it combines chunks of mu...
void Register(OptionsItf *opts)
bool GetOutput(std::string *utterance_id, CompactLattice *clat, std::string *sentence)
The user should call this to obtain output (This version should only be called if config...
Config class for the CollapseModel function.