31 using namespace kaldi;
34 typedef kaldi::int64 int64;
37 "Propagate the features through raw neural network model " 38 "and write the output. This version is optimized for GPU use. " 39 "If --apply-exp=true, apply the Exp() function to the output " 40 "before writing it out.\n" 42 "Usage: nnet3-compute-batch [options] <nnet-in> <features-rspecifier> " 43 "<matrix-wspecifier>\n" 44 " e.g.: nnet3-compute-batch final.raw scp:feats.scp " 45 "ark:nnet_prediction.ark\n";
53 bool apply_exp =
false, use_priors =
false;
54 std::string use_gpu =
"yes";
56 std::string word_syms_filename;
57 std::string ivector_rspecifier,
58 online_ivector_rspecifier,
60 int32 online_ivector_period = 0;
63 po.Register(
"ivectors", &ivector_rspecifier,
"Rspecifier for " 64 "iVectors as vectors (i.e. not estimated online); per " 65 "utterance by default, or per speaker if you provide the " 67 po.Register(
"utt2spk", &utt2spk_rspecifier,
"Rspecifier for " 68 "utt2spk option used to get ivectors per speaker");
69 po.Register(
"online-ivectors", &online_ivector_rspecifier,
"Rspecifier for " 70 "iVectors estimated online, as matrices. If you supply this," 71 " you must set the --online-ivector-period option.");
72 po.Register(
"online-ivector-period", &online_ivector_period,
"Number of " 73 "frames between iVectors in matrices supplied to the " 74 "--online-ivectors option");
75 po.Register(
"apply-exp", &apply_exp,
"If true, apply exp function to " 77 po.Register(
"use-gpu", &use_gpu,
78 "yes|no|optional|wait, only has effect if compiled with CUDA");
79 po.Register(
"use-priors", &use_priors,
"If true, subtract the logs of the " 80 "priors stored with the model (in this case, " 81 "a .mdl file is expected as input).");
84 CuDevice::RegisterDeviceOptions(&po);
89 if (po.NumArgs() != 3) {
95 CuDevice::Instantiate().AllowMultithreading();
96 CuDevice::Instantiate().SelectGpuId(use_gpu);
99 std::string nnet_rxfilename = po.GetArg(1),
100 feature_rspecifier = po.GetArg(2),
101 matrix_wspecifier = po.GetArg(3);
108 Input ki(nnet_rxfilename, &binary);
109 trans_model.
Read(ki.Stream(), binary);
110 am_nnet.
Read(ki.Stream(), binary);
114 Nnet &nnet = (use_priors ? am_nnet.
GetNnet() : raw_nnet);
121 priors = am_nnet.
Priors();
124 online_ivector_rspecifier);
126 ivector_rspecifier, utt2spk_rspecifier);
130 int32 num_success = 0, num_fail = 0;
131 std::string output_uttid;
139 for (; !feature_reader.Done(); feature_reader.Next()) {
140 std::string utt = feature_reader.Key();
143 KALDI_WARN <<
"Zero-length utterance: " << utt;
149 if (!ivector_rspecifier.empty()) {
150 if (!ivector_reader.HasKey(utt)) {
151 KALDI_WARN <<
"No iVector available for utterance " << utt;
158 if (!online_ivector_rspecifier.empty()) {
159 if (!online_ivector_reader.HasKey(utt)) {
160 KALDI_WARN <<
"No online iVector available for utterance " << utt;
165 online_ivector_reader.Value(utt));
169 inference.AcceptInput(utt, features, ivector, online_ivectors,
170 online_ivector_period);
172 std::string output_key;
174 while (inference.GetOutput(&output_key, &output)) {
177 matrix_writer.Write(output_key, output);
182 inference.Finished();
183 std::string output_key;
185 while (inference.GetOutput(&output_key, &output)) {
188 matrix_writer.Write(output_key, output);
192 CuDevice::Instantiate().PrintProfile();
194 double elapsed = timer.
Elapsed();
195 KALDI_LOG <<
"Time taken "<< elapsed <<
"s";
196 KALDI_LOG <<
"Done " << num_success <<
" utterances, failed for " 199 if (num_success != 0) {
204 }
catch(
const std::exception &e) {
205 std::cerr << e.what();
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable to be done at test time...
This class is for when you are reading something in random access, but it may actually be stored per-...
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
A templated class for writing objects to an archive or script file; see The Table concept...
const Nnet & GetNnet() const
void Read(std::istream &is, bool binary)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
This class implements a simplified interface to class NnetBatchComputer, which is suitable for progra...
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
void Register(OptionsItf *po)
const VectorBase< BaseFloat > & Priors() const
A class representing a vector.
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
double Elapsed() const
Returns time in seconds.
Config class for the CollapseModel function.