online2-wav-nnet2-am-compute.cc
Go to the documentation of this file.
1 // online2bin/online2-wav-nnet2-am-compute.cc
2 
3 // Copyright 2014 Johns Hopkins University (author: Daniel Povey)
4 // 2014 David Snyder
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include "feat/wave-reader.h"
24 #include "online2/onlinebin-util.h"
25 
26 int main(int argc, char *argv[]) {
27  try {
28  using namespace kaldi;
29  using namespace kaldi::nnet2;
30  typedef kaldi::int32 int32;
31  typedef kaldi::int64 int64;
32 
33  const char *usage =
34  "Simulates the online neural net computation for each file of input\n"
35  "features, and outputs as a matrix the result, with optional\n"
36  "iVector-based speaker adaptation. Note: some configuration values\n"
37  "and inputs are set via config files whose filenames are passed as\n"
38  "options. Used mostly for debugging.\n"
39  "Note: if you want it to apply a log (e.g. for log-likelihoods), use\n"
40  "--apply-log=true.\n"
41  "\n"
42  "Usage: online2-wav-nnet2-am-compute [options] <nnet-in>\n"
43  "<spk2utt-rspecifier> <wav-rspecifier> <feature-or-loglikes-wspecifier>\n"
44  "The spk2utt-rspecifier can just be <utterance-id> <utterance-id> if\n"
45  "you want to compute utterance by utterance.\n";
46 
47  BaseFloat chunk_length_secs = 0.05;
48  bool apply_log = false;
49  bool pad_input = true;
50  bool online = true;
51 
52  // feature_config includes configuration for the iVector adaptation,
53  // as well as the basic features.
54  OnlineNnet2FeaturePipelineConfig feature_config;
55  ParseOptions po(usage);
56  po.Register("apply-log", &apply_log, "Apply a log to the result of the computation "
57  "before outputting.");
58  po.Register("pad-input", &pad_input, "If true, duplicate the first and last frames "
59  "of input features as required for temporal context, to prevent #frames "
60  "of output being less than those of input.");
61  po.Register("chunk-length", &chunk_length_secs,
62  "Length of chunk size in seconds, that we process.");
63  po.Register("online", &online,
64  "You can set this to false to disable online iVector estimation "
65  "and have all the data for each utterance used, even at "
66  "utterance start. This is useful where you just want the best "
67  "results and don't care about online operation. Setting this to "
68  "false has the same effect as setting "
69  "--use-most-recent-ivector=true and --greedy-ivector-extractor=true "
70  "in the file given to --ivector-extraction-config, and "
71  "--chunk-length=-1.");
72 
73  feature_config.Register(&po);
74  po.Read(argc, argv);
75  if (po.NumArgs() != 4) {
76  po.PrintUsage();
77  return 1;
78  }
79 
80  std::string nnet2_rxfilename = po.GetArg(1),
81  spk2utt_rspecifier = po.GetArg(2),
82  wav_rspecifier = po.GetArg(3),
83  features_or_loglikes_wspecifier = po.GetArg(4);
84 
85  OnlineNnet2FeaturePipelineInfo feature_info(feature_config);
86  if (!online) {
89  chunk_length_secs = -1.0;
90  }
91 
92  Matrix<double> global_cmvn_stats;
93  if (feature_info.global_cmvn_stats_rxfilename != "")
95  &global_cmvn_stats);
96 
97  TransitionModel trans_model;
98  AmNnet am_nnet;
99  {
100  bool binary;
101  Input ki(nnet2_rxfilename, &binary);
102  trans_model.Read(ki.Stream(), binary);
103  am_nnet.Read(ki.Stream(), binary);
104  }
105  Nnet &nnet = am_nnet.GetNnet();
106 
107  int64 num_done = 0, num_frames = 0;
108  SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
109  RandomAccessTableReader<WaveHolder> wav_reader(wav_rspecifier);
110  BaseFloatCuMatrixWriter writer(features_or_loglikes_wspecifier);
111 
112  for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
113  std::string spk = spk2utt_reader.Key();
114  const std::vector<std::string> &uttlist = spk2utt_reader.Value();
115 
116  OnlineIvectorExtractorAdaptationState adaptation_state(
117  feature_info.ivector_extractor_info);
118  OnlineCmvnState cmvn_state(global_cmvn_stats);
119 
120  for (size_t i = 0; i < uttlist.size(); i++) {
121  std::string utt = uttlist[i];
122  if (!wav_reader.HasKey(utt)) {
123  KALDI_WARN << "Did not find audio for utterance " << utt;
124  continue;
125  }
126  const WaveData &wave_data = wav_reader.Value(utt);
127  // get the data for channel zero (if the signal is not mono, we only
128  // take the first channel).
129  SubVector<BaseFloat> data(wave_data.Data(), 0);
130 
131  OnlineNnet2FeaturePipeline feature_pipeline(feature_info);
132  feature_pipeline.SetAdaptationState(adaptation_state);
133  feature_pipeline.SetCmvnState(cmvn_state);
134 
135  BaseFloat samp_freq = wave_data.SampFreq();
136  int32 chunk_length;
137  if (chunk_length_secs > 0) {
138  chunk_length = int32(samp_freq * chunk_length_secs);
139  if (chunk_length == 0) chunk_length = 1;
140  } else {
141  chunk_length = std::numeric_limits<int32>::max();
142  }
143 
144  int32 samp_offset = 0;
145  while (samp_offset < data.Dim()) {
146  int32 samp_remaining = data.Dim() - samp_offset;
147  int32 num_samp = chunk_length < samp_remaining ? chunk_length
148  : samp_remaining;
149 
150  SubVector<BaseFloat> wave_part(data, samp_offset, num_samp);
151  feature_pipeline.AcceptWaveform(samp_freq, wave_part);
152 
153  samp_offset += num_samp;
154  if (samp_offset == data.Dim()) {
155  // no more input. flush out last frames
156  feature_pipeline.InputFinished();
157  }
158  }
159 
160  int32 feats_num_frames = feature_pipeline.NumFramesReady(),
161  feats_dim = feature_pipeline.Dim();
162  Matrix<BaseFloat> feats(feats_num_frames, feats_dim);
163 
164  for (int32 i = 0; i < feats_num_frames; i++) {
165  SubVector<BaseFloat> frame_vector(feats, i);
166  feature_pipeline.GetFrame(i, &frame_vector);
167  }
168 
169  // In an application you might avoid updating the adaptation state if
170  // you felt the utterance had low confidence. See lat/confidence.h
171  feature_pipeline.GetAdaptationState(&adaptation_state);
172  feature_pipeline.GetCmvnState(&cmvn_state);
173 
174  int32 output_frames = feats.NumRows(),
175  output_dim = nnet.OutputDim();
176  CuMatrix<BaseFloat> output(output_frames, output_dim),
177  feats_cu(feats);
178 
179  if (!pad_input)
180  output_frames -= nnet.LeftContext() + nnet.RightContext();
181  if (output_frames <= 0) {
182  KALDI_WARN << "Skipping utterance " << utt << " because output "
183  << "would be empty.";
184  continue;
185  }
186 
187  NnetComputation(nnet, feats_cu, pad_input, &output);
188 
189  if (apply_log) {
190  output.ApplyFloor(1.0e-20);
191  output.ApplyLog();
192  }
193 
194  writer.Write(utt, output);
195  num_frames += feats.NumRows();
196  num_done++;
197 
198  KALDI_LOG << "Processed data for utterance " << utt;
199  }
200  }
201 
202  KALDI_LOG << "Processed " << num_done << " feature files, "
203  << num_frames << " frames of input were processed.";
204 
205  return (num_done != 0 ? 0 : 1);
206  } catch(const std::exception& e) {
207  std::cerr << e.what() << '\n';
208  return -1;
209  }
210 } // main()
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
int32 LeftContext() const
Returns the left-context summed over all the Components...
Definition: nnet-nnet.cc:42
This configuration class is to set up OnlineNnet2FeaturePipelineInfo, which in turn is the configurat...
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
This class stores the adaptation state from the online iVector extractor, which can help you to initi...
int main(int argc, char *argv[])
void Read(std::istream &is, bool binary)
Definition: am-nnet.cc:39
int32 OutputDim() const
The output dimension of the network – typically the number of pdfs.
Definition: nnet-nnet.cc:31
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
BaseFloat SampFreq() const
Definition: wave-reader.h:126
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
Definition: matrix-common.h:71
void NnetComputation(const Nnet &nnet, const CuMatrixBase< BaseFloat > &input, bool pad_input, CuMatrixBase< BaseFloat > *output)
Does the basic neural net computation, on a sequence of data (e.g.
const Matrix< BaseFloat > & Data() const
Definition: wave-reader.h:124
This file contains a different version of the feature-extraction pipeline in online-feature-pipeline...
void Write(const std::string &key, const T &value) const
This class is responsible for storing configuration variables, objects and options for OnlineNnet2Fea...
void Register(const std::string &name, bool *ptr, const std::string &doc)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
std::istream & Stream()
Definition: kaldi-io.cc:826
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const T & Value(const std::string &key)
void Read(std::istream &is, bool binary)
int32 RightContext() const
Returns the right-context summed over all the Components...
Definition: nnet-nnet.cc:56
Struct OnlineCmvnState stores the state of CMVN adaptation between utterances (but not the state of t...
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
bool HasKey(const std::string &key)
This class's purpose is to read in Wave files.
Definition: wave-reader.h:106
int NumArgs() const
Number of positional parameters (c.f. argc-1).
std::string global_cmvn_stats_rxfilename
Options for online cmvn, read from config file.
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
OnlineNnet2FeaturePipeline is a class that's responsible for putting together the various parts of th...
#define KALDI_LOG
Definition: kaldi-error.h:153
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:501
const Nnet & GetNnet() const
Definition: am-nnet.h:61