online2-wav-nnet2-am-compute.cc File Reference

Functions

int main (int argc, char *argv[])
 

Function Documentation

main()

int main (int argc, char *argv[])

Definition at line 26 of file online2-wav-nnet2-am-compute.cc.

References WaveData::Data(), VectorBase< Real >::Dim(), SequentialTableReader< Holder >::Done(), ParseOptions::GetArg(), AmNnet::GetNnet(), OnlineNnet2FeaturePipelineInfo::global_cmvn_stats_rxfilename, OnlineIvectorExtractionInfo::greedy_ivector_extractor, RandomAccessTableReader< Holder >::HasKey(), OnlineNnet2FeaturePipelineInfo::ivector_extractor_info, KALDI_LOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), Nnet::LeftContext(), SequentialTableReader< Holder >::Next(), kaldi::nnet2::NnetComputation(), ParseOptions::NumArgs(), MatrixBase< Real >::NumRows(), Nnet::OutputDim(), ParseOptions::PrintUsage(), AmNnet::Read(), ParseOptions::Read(), TransitionModel::Read(), kaldi::ReadKaldiObject(), ParseOptions::Register(), OnlineNnet2FeaturePipelineConfig::Register(), Nnet::RightContext(), WaveData::SampFreq(), Input::Stream(), OnlineIvectorExtractionInfo::use_most_recent_ivector, RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), and TableWriter< Holder >::Write().

{
  try {
    using namespace kaldi;
    using namespace kaldi::nnet2;
    typedef kaldi::int32 int32;
    typedef kaldi::int64 int64;

    const char *usage =
        "Simulates the online neural net computation for each file of input\n"
        "features, and outputs as a matrix the result, with optional\n"
        "iVector-based speaker adaptation. Note: some configuration values\n"
        "and inputs are set via config files whose filenames are passed as\n"
        "options. Used mostly for debugging.\n"
        "Note: if you want it to apply a log (e.g. for log-likelihoods), use\n"
        "--apply-log=true.\n"
        "\n"
        "Usage: online2-wav-nnet2-am-compute [options] <nnet-in>\n"
        "<spk2utt-rspecifier> <wav-rspecifier> <feature-or-loglikes-wspecifier>\n"
        "The spk2utt-rspecifier can just be <utterance-id> <utterance-id> if\n"
        "you want to compute utterance by utterance.\n";

    BaseFloat chunk_length_secs = 0.05;
    bool apply_log = false;
    bool pad_input = true;
    bool online = true;

    // feature_config includes configuration for the iVector adaptation,
    // as well as the basic features.
    OnlineNnet2FeaturePipelineConfig feature_config;
    ParseOptions po(usage);
    po.Register("apply-log", &apply_log, "Apply a log to the result of the computation "
                "before outputting.");
    po.Register("pad-input", &pad_input, "If true, duplicate the first and last frames "
                "of input features as required for temporal context, to prevent #frames "
                "of output being less than those of input.");
    po.Register("chunk-length", &chunk_length_secs,
                "Length of chunk size in seconds, that we process.");
    po.Register("online", &online,
                "You can set this to false to disable online iVector estimation "
                "and have all the data for each utterance used, even at "
                "utterance start. This is useful where you just want the best "
                "results and don't care about online operation. Setting this to "
                "false has the same effect as setting "
                "--use-most-recent-ivector=true and --greedy-ivector-extractor=true "
                "in the file given to --ivector-extraction-config, and "
                "--chunk-length=-1.");

    feature_config.Register(&po);
    po.Read(argc, argv);
    if (po.NumArgs() != 4) {
      po.PrintUsage();
      return 1;
    }

    std::string nnet2_rxfilename = po.GetArg(1),
        spk2utt_rspecifier = po.GetArg(2),
        wav_rspecifier = po.GetArg(3),
        features_or_loglikes_wspecifier = po.GetArg(4);

    OnlineNnet2FeaturePipelineInfo feature_info(feature_config);
    if (!online) {
      feature_info.ivector_extractor_info.use_most_recent_ivector = true;
      feature_info.ivector_extractor_info.greedy_ivector_extractor = true;
      chunk_length_secs = -1.0;
    }

    Matrix<double> global_cmvn_stats;
    if (feature_info.global_cmvn_stats_rxfilename != "")
      ReadKaldiObject(feature_info.global_cmvn_stats_rxfilename,
                      &global_cmvn_stats);

    TransitionModel trans_model;
    AmNnet am_nnet;
    {
      bool binary;
      Input ki(nnet2_rxfilename, &binary);
      trans_model.Read(ki.Stream(), binary);
      am_nnet.Read(ki.Stream(), binary);
    }
    Nnet &nnet = am_nnet.GetNnet();

    int64 num_done = 0, num_frames = 0;
    SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
    RandomAccessTableReader<WaveHolder> wav_reader(wav_rspecifier);
    BaseFloatCuMatrixWriter writer(features_or_loglikes_wspecifier);

    for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
      std::string spk = spk2utt_reader.Key();
      const std::vector<std::string> &uttlist = spk2utt_reader.Value();

      OnlineIvectorExtractorAdaptationState adaptation_state(
          feature_info.ivector_extractor_info);
      OnlineCmvnState cmvn_state(global_cmvn_stats);

      for (size_t i = 0; i < uttlist.size(); i++) {
        std::string utt = uttlist[i];
        if (!wav_reader.HasKey(utt)) {
          KALDI_WARN << "Did not find audio for utterance " << utt;
          continue;
        }
        const WaveData &wave_data = wav_reader.Value(utt);
        // get the data for channel zero (if the signal is not mono, we only
        // take the first channel).
        SubVector<BaseFloat> data(wave_data.Data(), 0);

        OnlineNnet2FeaturePipeline feature_pipeline(feature_info);
        feature_pipeline.SetAdaptationState(adaptation_state);
        feature_pipeline.SetCmvnState(cmvn_state);

        BaseFloat samp_freq = wave_data.SampFreq();
        int32 chunk_length;
        if (chunk_length_secs > 0) {
          chunk_length = int32(samp_freq * chunk_length_secs);
          if (chunk_length == 0) chunk_length = 1;
        } else {
          chunk_length = std::numeric_limits<int32>::max();
        }

        int32 samp_offset = 0;
        while (samp_offset < data.Dim()) {
          int32 samp_remaining = data.Dim() - samp_offset;
          int32 num_samp = chunk_length < samp_remaining ? chunk_length
                                                         : samp_remaining;

          SubVector<BaseFloat> wave_part(data, samp_offset, num_samp);
          feature_pipeline.AcceptWaveform(samp_freq, wave_part);

          samp_offset += num_samp;
          if (samp_offset == data.Dim()) {
            // no more input. flush out last frames
            feature_pipeline.InputFinished();
          }
        }

        int32 feats_num_frames = feature_pipeline.NumFramesReady(),
            feats_dim = feature_pipeline.Dim();
        Matrix<BaseFloat> feats(feats_num_frames, feats_dim);

        for (int32 i = 0; i < feats_num_frames; i++) {
          SubVector<BaseFloat> frame_vector(feats, i);
          feature_pipeline.GetFrame(i, &frame_vector);
        }

        // In an application you might avoid updating the adaptation state if
        // you felt the utterance had low confidence. See lat/confidence.h
        feature_pipeline.GetAdaptationState(&adaptation_state);
        feature_pipeline.GetCmvnState(&cmvn_state);

        int32 output_frames = feats.NumRows(),
            output_dim = nnet.OutputDim();
        if (!pad_input)
          output_frames -= nnet.LeftContext() + nnet.RightContext();
        if (output_frames <= 0) {
          KALDI_WARN << "Skipping utterance " << utt << " because output "
                     << "would be empty.";
          continue;
        }

        // Allocate the output only after the --pad-input adjustment above,
        // so that its row count matches what NnetComputation() produces.
        CuMatrix<BaseFloat> output(output_frames, output_dim),
            feats_cu(feats);

        NnetComputation(nnet, feats_cu, pad_input, &output);

        if (apply_log) {
          output.ApplyFloor(1.0e-20);
          output.ApplyLog();
        }

        writer.Write(utt, output);
        num_frames += feats.NumRows();
        num_done++;

        KALDI_LOG << "Processed data for utterance " << utt;
      }
    }

    KALDI_LOG << "Processed " << num_done << " feature files, "
              << num_frames << " frames of input were processed.";

    return (num_done != 0 ? 0 : 1);
  } catch(const std::exception& e) {
    std::cerr << e.what() << '\n';
    return -1;
  }
} // main()