nnet3-discriminative-get-egs.cc
Go to the documentation of this file.
1 // nnet3bin/nnet3-discriminative-get-egs.cc
2 
3 // Copyright 2015 Johns Hopkins University (author: Daniel Povey)
4 // 2014-2015 Vimal Manohar
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include <sstream>
22 
23 #include "base/kaldi-common.h"
24 #include "util/common-utils.h"
25 #include "hmm/transition-model.h"
26 #include "hmm/posterior.h"
30 #include "chain/chain-supervision.h"
31 
32 namespace kaldi {
33 namespace nnet3 {
34 
35 // This function does all the processing for one utterance, and outputs the
36 // examples to 'example_writer'.
37 // returns true if we got as far as calling GetChunksForUtterance()
38 // [in which case stats will be accumulated by class UtteranceSplitter]
40  const TransitionModel &tmodel,
41  const MatrixBase<BaseFloat> &feats,
42  const MatrixBase<BaseFloat> *ivector_feats,
43  int32 ivector_period,
45  const std::string &utt_id,
46  bool compress,
47  UtteranceSplitter *utt_splitter,
48  NnetDiscriminativeExampleWriter *example_writer) {
49  KALDI_ASSERT(supervision.num_sequences == 1);
50  int32 num_input_frames = feats.NumRows(),
51  num_output_frames = supervision.frames_per_sequence;
52 
53  if (!utt_splitter->LengthsMatch(utt_id, num_input_frames, num_output_frames))
54  return false; // LengthsMatch() will have printed a warning.
55 
56  std::vector<ChunkTimeInfo> chunks;
57 
58  utt_splitter->GetChunksForUtterance(num_input_frames, &chunks);
59 
60  if (chunks.empty()) {
61  KALDI_WARN << "Not producing egs for utterance " << utt_id
62  << " because it is too short: "
63  << num_input_frames << " frames.";
64  }
65 
66  int32 frame_subsampling_factor = utt_splitter->Config().frame_subsampling_factor;
67 
69  supervision);
70 
71  for (size_t c = 0; c < chunks.size(); c++) {
72  ChunkTimeInfo &chunk = chunks[c];
73 
74  NnetDiscriminativeExample nnet_discriminative_eg;
75  nnet_discriminative_eg.outputs.resize(1);
76 
77  int32 start_frame_subsampled = chunk.first_frame / frame_subsampling_factor,
78  num_frames_subsampled = chunk.num_frames / frame_subsampling_factor;
79 
81 
82  splitter.GetFrameRange(start_frame_subsampled,
83  num_frames_subsampled,
84  (c == 0 ? false : true),
85  &supervision_part);
86 
87  SubVector<BaseFloat> output_weights(
88  &(chunk.output_weights[0]),
89  static_cast<int32>(chunk.output_weights.size()));
90 
91  int32 first_frame = 0; // we shift the time-indexes of all these parts so
92  // that the supervised part starts from frame 0.
93  NnetDiscriminativeSupervision nnet_supervision("output", supervision_part,
94  output_weights,
95  first_frame,
96  frame_subsampling_factor);
97  nnet_discriminative_eg.outputs[0].Swap(&nnet_supervision);
98 
99  nnet_discriminative_eg.inputs.resize(ivector_feats != NULL ? 2 : 1);
100 
101 
102  int32 tot_input_frames = chunk.left_context + chunk.num_frames +
103  chunk.right_context;
104 
105  Matrix<BaseFloat> input_frames(tot_input_frames, feats.NumCols(),
106  kUndefined);
107 
108  int32 start_frame = chunk.first_frame - chunk.left_context;
109  for (int32 t = start_frame; t < start_frame + tot_input_frames; t++) {
110  int32 t2 = t;
111  if (t2 < 0) t2 = 0;
112  if (t2 >= num_input_frames) t2 = num_input_frames - 1;
113  int32 j = t - start_frame;
114  SubVector<BaseFloat> src(feats, t2),
115  dest(input_frames, j);
116  dest.CopyFromVec(src);
117  }
118 
119  NnetIo input_io("input", -chunk.left_context, input_frames);
120  nnet_discriminative_eg.inputs[0].Swap(&input_io);
121 
122  if (ivector_feats != NULL) {
123  // if applicable, add the iVector feature.
124  // choose iVector from a random frame in the chunk
125  int32 ivector_frame = RandInt(start_frame,
126  start_frame + num_input_frames - 1),
127  ivector_frame_subsampled = ivector_frame / ivector_period;
128  if (ivector_frame_subsampled < 0)
129  ivector_frame_subsampled = 0;
130  if (ivector_frame_subsampled >= ivector_feats->NumRows())
131  ivector_frame_subsampled = ivector_feats->NumRows() - 1;
132  Matrix<BaseFloat> ivector(1, ivector_feats->NumCols());
133  ivector.Row(0).CopyFromVec(ivector_feats->Row(ivector_frame_subsampled));
134  NnetIo ivector_io("ivector", 0, ivector);
135  nnet_discriminative_eg.inputs[1].Swap(&ivector_io);
136  }
137 
138  if (compress)
139  nnet_discriminative_eg.Compress();
140 
141  std::ostringstream os;
142  os << utt_id << "-" << chunk.first_frame;
143 
144  std::string key = os.str(); // key is <utt_id>-<frame_id>
145 
146  example_writer->Write(key, nnet_discriminative_eg);
147  }
148  return true;
149 }
150 
151 
152 } // namespace nnet3
153 } // namespace kaldi
154 
155 int main(int argc, char *argv[]) {
156  try {
157  using namespace kaldi;
158  using namespace kaldi::nnet3;
159  typedef kaldi::int32 int32;
160  typedef kaldi::int64 int64;
161 
162  const char *usage =
163  "Get frame-by-frame examples of data for nnet3+sequence neural network\n"
164  "training. This involves breaking up utterances into pieces of sizes\n"
165  "determined by the --num-frames option.\n"
166  "\n"
167  "Usage: nnet3-discriminative-get-egs [options] <model> <features-rspecifier> "
168  "<denominator-lattice-rspecifier> <numerator-alignment-rspecifier> <egs-wspecifier>\n"
169  "\n"
170  "An example [where $feats expands to the actual features]:\n"
171  " nnet3-discriminative-get-egs --left-context=25 --right-context=9 --num-frames=150,100,90 \\\n"
172  " \"$feats\" \"ark,s,cs:gunzip -c lat.1.gz\" scp:ali.scp ark:degs.1.ark\n";
173 
174  bool compress = true;
175  int32 length_tolerance = 100, online_ivector_period = 1;
176 
177  std::string online_ivector_rspecifier;
178 
179  ExampleGenerationConfig eg_config; // controls num-frames,
180  // left/right-context, etc.
182 
183  ParseOptions po(usage);
184 
185  eg_config.Register(&po);
186  po.Register("compress", &compress, "If true, write egs in "
187  "compressed format (recommended)");
188  po.Register("ivectors", &online_ivector_rspecifier, "Alias for --online-ivectors "
189  "option, for back compatibility");
190  po.Register("online-ivectors", &online_ivector_rspecifier, "Rspecifier of ivector "
191  "features, as a matrix.");
192  po.Register("online-ivector-period", &online_ivector_period, "Number of frames "
193  "between iVectors in matrices supplied to the --online-ivectors "
194  "option");
195  po.Register("length-tolerance", &length_tolerance, "Tolerance for "
196  "difference in num-frames between feat and ivector matrices");
197 
198  splitter_config.Register(&po);
199 
200  po.Read(argc, argv);
201 
202  if (po.NumArgs() != 5) {
203  po.PrintUsage();
204  exit(1);
205  }
206 
207  eg_config.ComputeDerived();
208  UtteranceSplitter utt_splitter(eg_config);
209 
210  std::string model_wxfilename = po.GetArg(1),
211  feature_rspecifier = po.GetArg(2),
212  den_lat_rspecifier = po.GetArg(3),
213  num_ali_rspecifier = po.GetArg(4),
214  examples_wspecifier = po.GetArg(5);
215 
216 
217  TransitionModel tmodel;
218  {
219  bool binary;
220  Input ki(model_wxfilename, &binary);
221  tmodel.Read(ki.Stream(), binary);
222  }
223 
224  SequentialBaseFloatMatrixReader feat_reader(feature_rspecifier);
225  RandomAccessLatticeReader den_lat_reader(den_lat_rspecifier);
226  RandomAccessInt32VectorReader ali_reader(num_ali_rspecifier);
227  NnetDiscriminativeExampleWriter example_writer(examples_wspecifier);
228  RandomAccessBaseFloatMatrixReader online_ivector_reader(
229  online_ivector_rspecifier);
230 
231  int32 num_err = 0;
232 
233  for (; !feat_reader.Done(); feat_reader.Next()) {
234  std::string key = feat_reader.Key();
235  const Matrix<BaseFloat> &feats = feat_reader.Value();
236  if (!den_lat_reader.HasKey(key)) {
237  KALDI_WARN << "No denominator lattice for key " << key;
238  num_err++;
239  } else if (!ali_reader.HasKey(key)) {
240  KALDI_WARN << "No numerator alignment for key " << key;
241  num_err++;
242  } else {
244  if (!supervision.Initialize(ali_reader.Value(key),
245  den_lat_reader.Value(key),
246  1.0)) {
247  KALDI_WARN << "Failed to convert lattice to supervision "
248  << "for utterance " << key;
249  num_err++;
250  continue;
251  }
252 
253  const Matrix<BaseFloat> *online_ivector_feats = NULL;
254  if (!online_ivector_rspecifier.empty()) {
255  if (!online_ivector_reader.HasKey(key)) {
256  KALDI_WARN << "No iVectors for utterance " << key;
257  num_err++;
258  continue;
259  } else {
260  // this address will be valid until we call HasKey() or Value()
261  // again.
262  online_ivector_feats = &(online_ivector_reader.Value(key));
263  }
264  }
265  if (online_ivector_feats != NULL &&
266  (abs(feats.NumRows() - (online_ivector_feats->NumRows() *
267  online_ivector_period)) > length_tolerance
268  || online_ivector_feats->NumRows() == 0)) {
269  KALDI_WARN << "Length difference between feats " << feats.NumRows()
270  << " and iVectors " << online_ivector_feats->NumRows()
271  << "exceeds tolerance " << length_tolerance;
272  num_err++;
273  continue;
274  }
275  if (!ProcessFile(splitter_config, tmodel,
276  feats, online_ivector_feats, online_ivector_period,
277  supervision, key, compress,
278  &utt_splitter, &example_writer))
279  num_err++;
280  }
281  }
282  if (num_err > 0)
283  KALDI_WARN << num_err << " utterances had errors and could "
284  "not be processed.";
285  // utt_splitter prints diagnostics.
286  return utt_splitter.ExitStatus();
287  } catch(const std::exception &e) {
288  std::cerr << e.what() << '\n';
289  return -1;
290  }
291 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void GetFrameRange(int32 begin_frame, int32 frames_per_sequence, bool normalize, DiscriminativeSupervision *supervision) const
bool LengthsMatch(const std::string &utt, int32 utterance_length, int32 supervision_length, int32 length_tolerance=0) const
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
Base class which provides matrix operations not involving resizing or allocation. ...
Definition: kaldi-matrix.h:49
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
int main(int argc, char *argv[])
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
void CopyFromVec(const VectorBase< Real > &v)
Copy data from another vector (must match own size).
std::istream & Stream()
Definition: kaldi-io.cc:826
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
const T & Value(const std::string &key)
static bool ProcessFile(const discriminative::SplitDiscriminativeSupervisionOptions &config, const TransitionModel &tmodel, const MatrixBase< BaseFloat > &feats, const MatrixBase< BaseFloat > *ivector_feats, int32 ivector_period, const discriminative::DiscriminativeSupervision &supervision, const std::string &utt_id, bool compress, UtteranceSplitter *utt_splitter, NnetDiscriminativeExampleWriter *example_writer)
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
#define KALDI_WARN
Definition: kaldi-error.h:150
bool HasKey(const std::string &key)
const ExampleGenerationConfig & Config() const
int NumArgs() const
Number of positional parameters (c.f. argc-1).
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
std::vector< NnetIo > inputs
'inputs' contains the input to the network – normally it has just one element called "input"...
bool Initialize(const std::vector< int32 > &alignment, const Lattice &lat, BaseFloat weight)
std::vector< NnetDiscriminativeSupervision > outputs
'outputs' contains the sequence output supervision.
void GetChunksForUtterance(int32 utterance_length, std::vector< ChunkTimeInfo > *chunk_info)
struct ChunkTimeInfo is used by class UtteranceSplitter to output information about how we split an u...
std::vector< BaseFloat > output_weights
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:501
NnetDiscriminativeExample is like NnetExample, but specialized for sequence training.
int32 RandInt(int32 min_val, int32 max_val, struct RandomState *state)
Definition: kaldi-math.cc:95
void ComputeDerived()
This function decodes 'num_frames_str' into 'num_frames', and ensures that the members of 'num_frames...