28 using namespace kaldi;
31 typedef kaldi::int64 int64;
34 "Simulates the online neural net computation for each file of input\n" 35 "features, and outputs as a matrix the result, with optional\n" 36 "iVector-based speaker adaptation. Note: some configuration values\n" 37 "and inputs are set via config files whose filenames are passed as\n" 38 "options. Used mostly for debugging.\n" 39 "Note: if you want it to apply a log (e.g. for log-likelihoods), use\n" 42 "Usage: online2-wav-nnet2-am-compute [options] <nnet-in>\n" 43 "<spk2utt-rspecifier> <wav-rspecifier> <feature-or-loglikes-wspecifier>\n" 44 "The spk2utt-rspecifier can just be <utterance-id> <utterance-id> if\n" 45 "you want to compute utterance by utterance.\n";
48 bool apply_log =
false;
49 bool pad_input =
true;
56 po.Register(
"apply-log", &apply_log,
"Apply a log to the result of the computation " 57 "before outputting.");
58 po.Register(
"pad-input", &pad_input,
"If true, duplicate the first and last frames " 59 "of input features as required for temporal context, to prevent #frames " 60 "of output being less than those of input.");
61 po.Register(
"chunk-length", &chunk_length_secs,
62 "Length of chunk size in seconds, that we process.");
63 po.Register(
"online", &online,
64 "You can set this to false to disable online iVector estimation " 65 "and have all the data for each utterance used, even at " 66 "utterance start. This is useful where you just want the best " 67 "results and don't care about online operation. Setting this to " 68 "false has the same effect as setting " 69 "--use-most-recent-ivector=true and --greedy-ivector-extractor=true " 70 "in the file given to --ivector-extraction-config, and " 71 "--chunk-length=-1.");
75 if (po.NumArgs() != 4) {
80 std::string nnet2_rxfilename = po.GetArg(1),
81 spk2utt_rspecifier = po.GetArg(2),
82 wav_rspecifier = po.GetArg(3),
83 features_or_loglikes_wspecifier = po.GetArg(4);
87 feature_info.ivector_extractor_info.use_most_recent_ivector =
true;
88 feature_info.ivector_extractor_info.greedy_ivector_extractor =
true;
89 chunk_length_secs = -1.0;
93 if (feature_info.global_cmvn_stats_rxfilename !=
"")
101 Input ki(nnet2_rxfilename, &binary);
102 trans_model.
Read(ki.Stream(), binary);
103 am_nnet.
Read(ki.Stream(), binary);
107 int64 num_done = 0, num_frames = 0;
112 for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
113 std::string spk = spk2utt_reader.Key();
114 const std::vector<std::string> &uttlist = spk2utt_reader.Value();
117 feature_info.ivector_extractor_info);
120 for (
size_t i = 0;
i < uttlist.size();
i++) {
121 std::string utt = uttlist[
i];
122 if (!wav_reader.HasKey(utt)) {
123 KALDI_WARN <<
"Did not find audio for utterance " << utt;
126 const WaveData &wave_data = wav_reader.Value(utt);
132 feature_pipeline.SetAdaptationState(adaptation_state);
133 feature_pipeline.SetCmvnState(cmvn_state);
137 if (chunk_length_secs > 0) {
138 chunk_length =
int32(samp_freq * chunk_length_secs);
139 if (chunk_length == 0) chunk_length = 1;
141 chunk_length = std::numeric_limits<int32>::max();
144 int32 samp_offset = 0;
145 while (samp_offset < data.Dim()) {
146 int32 samp_remaining = data.Dim() - samp_offset;
147 int32 num_samp = chunk_length < samp_remaining ? chunk_length
151 feature_pipeline.AcceptWaveform(samp_freq, wave_part);
153 samp_offset += num_samp;
154 if (samp_offset == data.Dim()) {
156 feature_pipeline.InputFinished();
160 int32 feats_num_frames = feature_pipeline.NumFramesReady(),
161 feats_dim = feature_pipeline.Dim();
164 for (int32
i = 0;
i < feats_num_frames;
i++) {
166 feature_pipeline.GetFrame(
i, &frame_vector);
171 feature_pipeline.GetAdaptationState(&adaptation_state);
172 feature_pipeline.GetCmvnState(&cmvn_state);
174 int32 output_frames = feats.NumRows(),
181 if (output_frames <= 0) {
182 KALDI_WARN <<
"Skipping utterance " << utt <<
" because output " 183 <<
"would be empty.";
190 output.ApplyFloor(1.0e-20);
194 writer.Write(utt, output);
195 num_frames += feats.NumRows();
198 KALDI_LOG <<
"Processed data for utterance " << utt;
202 KALDI_LOG <<
"Processed " << num_done <<
" feature files, " 203 << num_frames <<
" frames of input were processed.";
205 return (num_done != 0 ? 0 : 1);
206 }
catch(
const std::exception& e) {
207 std::cerr << e.what() <<
'\n';
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
int32 LeftContext() const
Returns the left-context summed over all the Components...
This configuration class is to set up OnlineNnet2FeaturePipelineInfo, which in turn is the configurat...
void Read(std::istream &is, bool binary)
int32 OutputDim() const
The output dimension of the network – typically the number of pdfs.
A templated class for writing objects to an archive or script file; see The Table concept...
BaseFloat SampFreq() const
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
void NnetComputation(const Nnet &nnet, const CuMatrixBase< BaseFloat > &input, bool pad_input, CuMatrixBase< BaseFloat > *output)
Does the basic neural net computation, on a sequence of data (e.g. an utterance).
const Matrix< BaseFloat > & Data() const
This class is responsible for storing configuration variables, objects and options for OnlineNnet2Fea...
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
void Read(std::istream &is, bool binary)
int32 RightContext() const
Returns the right-context summed over all the Components...
Struct OnlineCmvnState stores the state of CMVN adaptation between utterances (but not the state of t...
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
void Register(OptionsItf *opts)
This class's purpose is to read in Wave files.
OnlineNnet2FeaturePipeline is a class that's responsible for putting together the various parts of th...
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
const Nnet & GetNnet() const