29 int main(
int argc,
char *argv[]) {
31 using namespace kaldi;
34 typedef kaldi::int64 int64;
37 "Propagate the features through raw neural network model " 38 "and write the output. This version is optimized for GPU use. " 39 "If --apply-exp=true, apply the Exp() function to the output " 40 "before writing it out.\n" 42 "Usage: nnet3-compute-batch [options] <nnet-in> <features-rspecifier> " 43 "<matrix-wspecifier>\n" 44 " e.g.: nnet3-compute-batch final.raw scp:feats.scp " 45 "ark:nnet_prediction.ark\n";
53 bool apply_exp =
false, use_priors =
false;
54 std::string use_gpu =
"yes";
56 std::string word_syms_filename;
57 std::string ivector_rspecifier,
58 online_ivector_rspecifier,
60 int32 online_ivector_period = 0;
63 po.
Register(
"ivectors", &ivector_rspecifier,
"Rspecifier for " 64 "iVectors as vectors (i.e. not estimated online); per " 65 "utterance by default, or per speaker if you provide the " 67 po.
Register(
"utt2spk", &utt2spk_rspecifier,
"Rspecifier for " 68 "utt2spk option used to get ivectors per speaker");
69 po.
Register(
"online-ivectors", &online_ivector_rspecifier,
"Rspecifier for " 70 "iVectors estimated online, as matrices. If you supply this," 71 " you must set the --online-ivector-period option.");
72 po.
Register(
"online-ivector-period", &online_ivector_period,
"Number of " 73 "frames between iVectors in matrices supplied to the " 74 "--online-ivectors option");
75 po.
Register(
"apply-exp", &apply_exp,
"If true, apply exp function to " 78 "yes|no|optional|wait, only has effect if compiled with CUDA");
79 po.
Register(
"use-priors", &use_priors,
"If true, subtract the logs of the " 80 "priors stored with the model (in this case, " 81 "a .mdl file is expected as input).");
84 CuDevice::RegisterDeviceOptions(&po);
95 CuDevice::Instantiate().AllowMultithreading();
96 CuDevice::Instantiate().SelectGpuId(use_gpu);
99 std::string nnet_rxfilename = po.
GetArg(1),
100 feature_rspecifier = po.
GetArg(2),
101 matrix_wspecifier = po.
GetArg(3);
108 Input ki(nnet_rxfilename, &binary);
114 Nnet &nnet = (use_priors ? am_nnet.
GetNnet() : raw_nnet);
121 priors = am_nnet.
Priors();
124 online_ivector_rspecifier);
126 ivector_rspecifier, utt2spk_rspecifier);
130 int32 num_success = 0, num_fail = 0;
131 std::string output_uttid;
139 for (; !feature_reader.
Done(); feature_reader.
Next()) {
140 std::string utt = feature_reader.
Key();
143 KALDI_WARN <<
"Zero-length utterance: " << utt;
149 if (!ivector_rspecifier.empty()) {
150 if (!ivector_reader.
HasKey(utt)) {
151 KALDI_WARN <<
"No iVector available for utterance " << utt;
158 if (!online_ivector_rspecifier.empty()) {
159 if (!online_ivector_reader.
HasKey(utt)) {
160 KALDI_WARN <<
"No online iVector available for utterance " << utt;
165 online_ivector_reader.
Value(utt));
169 inference.
AcceptInput(utt, features, ivector, online_ivectors,
170 online_ivector_period);
172 std::string output_key;
174 while (inference.
GetOutput(&output_key, &output)) {
177 matrix_writer.
Write(output_key, output);
183 std::string output_key;
185 while (inference.
GetOutput(&output_key, &output)) {
188 matrix_writer.
Write(output_key, output);
192 CuDevice::Instantiate().PrintProfile();
194 double elapsed = timer.
Elapsed();
195 KALDI_LOG <<
"Time taken "<< elapsed <<
"s";
196 KALDI_LOG <<
"Done " << num_success <<
" utterances, failed for " 199 if (num_success != 0) {
204 }
catch(
const std::exception &e) {
205 std::cerr << e.what();
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable to be done at test time...
int main(int argc, char *argv[])
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
This class is for when you are reading something in random access, but it may actually be stored per-...
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
A templated class for writing objects to an archive or script file; see The Table concept...
const Nnet & GetNnet() const
void AcceptInput(const std::string &utterance_id, const Matrix< BaseFloat > &input, const Vector< BaseFloat > *ivector, const Matrix< BaseFloat > *online_ivectors, int32 online_ivector_period)
The user should call this one by one for the utterances that this class needs to compute (intersperse...
void Write(const std::string &key, const T &value) const
void Read(std::istream &is, bool binary)
void Register(const std::string &name, bool *ptr, const std::string &doc)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
This file contains some miscellaneous functions dealing with class Nnet.
Allows random access to a collection of objects in an archive or script file; see The Table concept...
This class implements a simplified interface to class NnetBatchComputer, which is suitable for progra...
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
const T & Value(const std::string &key)
void Read(std::istream &is, bool binary)
void Finished()
The user should call this after the last input has been provided via AcceptInput().
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
bool HasKey(const std::string &key)
void Register(OptionsItf *po)
const VectorBase< BaseFloat > & Priors() const
int NumArgs() const
Number of positional parameters (cf. argc-1).
A class representing a vector.
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
const T & Value(const std::string &key)
double Elapsed() const
Returns time in seconds.
bool GetOutput(std::string *utterance_id, Matrix< BaseFloat > *output)
The user should call this to obtain output.
Config class for the CollapseModel function.