nnet3-copy-egs.cc
Go to the documentation of this file.
1 // nnet3bin/nnet3-copy-egs.cc
2 
3 // Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey)
4 // 2014 Vimal Manohar
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include "base/kaldi-common.h"
22 #include "util/common-utils.h"
23 #include "hmm/transition-model.h"
24 #include "nnet3/nnet-example.h"
26 
27 namespace kaldi {
28 namespace nnet3 {
29 
30 // renames outputs named "output" to new_name
31 void RenameOutputs(const std::string &new_name, NnetExample *eg) {
32  bool found_output = false;
33  for (std::vector<NnetIo>::iterator it = eg->io.begin();
34  it != eg->io.end(); ++it) {
35  if (it->name == "output") {
36  it->name = new_name;
37  found_output = true;
38  }
39  }
40 
41  if (!found_output)
42  KALDI_ERR << "No io-node with name 'output'"
43  << "exists in eg.";
44 }
45 
46 // scales the supervision for 'output' by a factor of "weight"
48  if (weight == 1.0) return;
49 
50  bool found_output = false;
51  for (std::vector<NnetIo>::iterator it = eg->io.begin();
52  it != eg->io.end(); ++it) {
53  if (it->name == "output") {
54  it->features.Scale(weight);
55  found_output = true;
56  }
57  }
58 
59  if (!found_output)
60  KALDI_ERR << "No supervision with name 'output'"
61  << "exists in eg.";
62 }
63 
64 // returns an integer randomly drawn with expected value "expected_count"
65 // (will be either floor(expected_count) or ceil(expected_count)).
66 int32 GetCount(double expected_count) {
67  KALDI_ASSERT(expected_count >= 0.0);
68  int32 ans = floor(expected_count);
69  expected_count -= ans;
70  if (WithProb(expected_count))
71  ans++;
72  return ans;
73 }
74 
83  int32 *min_input_t,
84  int32 *max_input_t,
85  int32 *min_output_t,
86  int32 *max_output_t) {
87  bool done_input = false, done_output = false;
88  int32 num_indexes = eg.io.size();
89  for (int32 i = 0; i < num_indexes; i++) {
90  const NnetIo &io = eg.io[i];
91  std::vector<Index>::const_iterator iter = io.indexes.begin(),
92  end = io.indexes.end();
93  // Should not have an empty input/output type.
94  KALDI_ASSERT(!io.indexes.empty());
95  if (io.name == "input" || io.name == "output") {
96  int32 min_t = iter->t, max_t = iter->t;
97  for (; iter != end; ++iter) {
98  int32 this_t = iter->t;
99  min_t = std::min(min_t, this_t);
100  max_t = std::max(max_t, this_t);
101  if (iter->n != 0) {
102  KALDI_WARN << "Example does not contain just a single example; "
103  << "too late to do frame selection or reduce context.";
104  return false;
105  }
106  }
107  if (io.name == "input") {
108  done_input = true;
109  *min_input_t = min_t;
110  *max_input_t = max_t;
111  } else {
112  KALDI_ASSERT(io.name == "output");
113  done_output = true;
114  *min_output_t = min_t;
115  *max_output_t = max_t;
116  }
117  } else {
118  for (; iter != end; ++iter) {
119  if (iter->n != 0) {
120  KALDI_WARN << "Example does not contain just a single example; "
121  << "too late to do frame selection or reduce context.";
122  return false;
123  }
124  }
125  }
126  }
127  if (!done_input) {
128  KALDI_WARN << "Example does not have any input named 'input'";
129  return false;
130  }
131  if (!done_output) {
132  KALDI_WARN << "Example does not have any output named 'output'";
133  return false;
134  }
135  return true;
136 }
137 
145 void FilterExample(const NnetExample &eg,
146  int32 min_input_t,
147  int32 max_input_t,
148  int32 min_output_t,
149  int32 max_output_t,
150  NnetExample *eg_out) {
151  eg_out->io.clear();
152  eg_out->io.resize(eg.io.size());
153  for (size_t i = 0; i < eg.io.size(); i++) {
154  bool is_input_or_output;
155  int32 min_t, max_t;
156  const NnetIo &io_in = eg.io[i];
157  NnetIo &io_out = eg_out->io[i];
158  const std::string &name = io_in.name;
159  io_out.name = name;
160  if (name == "input") {
161  min_t = min_input_t;
162  max_t = max_input_t;
163  is_input_or_output = true;
164  } else if (name == "output") {
165  min_t = min_output_t;
166  max_t = max_output_t;
167  is_input_or_output = true;
168  } else {
169  is_input_or_output = false;
170  }
171  if (!is_input_or_output) { // Just copy everything.
172  io_out.indexes = io_in.indexes;
173  io_out.features = io_in.features;
174  } else {
175  const std::vector<Index> &indexes_in = io_in.indexes;
176  std::vector<Index> &indexes_out = io_out.indexes;
177  indexes_out.reserve(indexes_in.size());
178  int32 num_indexes = indexes_in.size(), num_kept = 0;
179  KALDI_ASSERT(io_in.features.NumRows() == num_indexes);
180  std::vector<bool> keep(num_indexes, false);
181  std::vector<Index>::const_iterator iter_in = indexes_in.begin(),
182  end_in = indexes_in.end();
183  std::vector<bool>::iterator iter_out = keep.begin();
184  for (; iter_in != end_in; ++iter_in,++iter_out) {
185  int32 t = iter_in->t;
186  bool is_within_range = (t >= min_t && t <= max_t);
187  *iter_out = is_within_range;
188  if (is_within_range) {
189  indexes_out.push_back(*iter_in);
190  num_kept++;
191  }
192  }
193  KALDI_ASSERT(iter_out == keep.end());
194  if (num_kept == 0)
195  KALDI_ERR << "FilterExample removed all indexes for '" << name << "'";
196 
197  FilterGeneralMatrixRows(io_in.features, keep,
198  &io_out.features);
199  KALDI_ASSERT(io_out.features.NumRows() == num_kept &&
200  indexes_out.size() == static_cast<size_t>(num_kept));
201  }
202  }
203 }
204 
205 
223  std::string frame_str,
224  int32 left_context,
225  int32 right_context,
226  int32 frame_shift,
227  NnetExample *eg_out) {
228  static bool warned_left = false, warned_right = false;
229  int32 min_input_t, max_input_t,
230  min_output_t, max_output_t;
231  if (!ContainsSingleExample(eg, &min_input_t, &max_input_t,
232  &min_output_t, &max_output_t))
233  KALDI_ERR << "Too late to perform frame selection/context reduction on "
234  << "these examples (already merged?)";
235  if (frame_str != "") {
236  // select one frame.
237  if (frame_str == "random") {
238  min_output_t = max_output_t = RandInt(min_output_t,
239  max_output_t);
240  } else {
241  int32 frame;
242  if (!ConvertStringToInteger(frame_str, &frame))
243  KALDI_ERR << "Invalid option --frame='" << frame_str << "'";
244  if (frame < min_output_t || frame > max_output_t) {
245  // Frame is out of range. Should happen only rarely. Calling code
246  // makes sure of this.
247  return false;
248  }
249  min_output_t = max_output_t = frame;
250  }
251  }
252  if (left_context != -1) {
253  if (!warned_left && min_input_t > min_output_t - left_context) {
254  warned_left = true;
255  KALDI_WARN << "You requested --left-context=" << left_context
256  << ", but example only has left-context of "
257  << (min_output_t - min_input_t)
258  << " (will warn only once; this may be harmless if "
259  "using any --*left-context-initial options)";
260  }
261  min_input_t = std::max(min_input_t, min_output_t - left_context);
262  }
263  if (right_context != -1) {
264  if (!warned_right && max_input_t < max_output_t + right_context) {
265  warned_right = true;
266  KALDI_WARN << "You requested --right-context=" << right_context
267  << ", but example only has right-context of "
268  << (max_input_t - max_output_t)
269  << " (will warn only once; this may be harmless if "
270  "using any --*right-context-final options.";
271  }
272  max_input_t = std::min(max_input_t, max_output_t + right_context);
273  }
274  FilterExample(eg,
275  min_input_t, max_input_t,
276  min_output_t, max_output_t,
277  eg_out);
278  if (frame_shift != 0) {
279  std::vector<std::string> exclude_names; // we can later make this
280  exclude_names.push_back(std::string("ivector")); // configurable.
281  ShiftExampleTimes(frame_shift, exclude_names, eg_out);
282  }
283  return true;
284 }
285 
286 
287 } // namespace nnet3
288 } // namespace kaldi
289 
290 int main(int argc, char *argv[]) {
291  try {
292  using namespace kaldi;
293  using namespace kaldi::nnet3;
294  typedef kaldi::int32 int32;
295  typedef kaldi::int64 int64;
296 
297  const char *usage =
298  "Copy examples (single frames or fixed-size groups of frames) for neural\n"
299  "network training, possibly changing the binary mode. Supports multiple wspecifiers, in\n"
300  "which case it will write the examples round-robin to the outputs.\n"
301  "\n"
302  "Usage: nnet3-copy-egs [options] <egs-rspecifier> <egs-wspecifier1> [<egs-wspecifier2> ...]\n"
303  "\n"
304  "e.g.\n"
305  "nnet3-copy-egs ark:train.egs ark,t:text.egs\n"
306  "or:\n"
307  "nnet3-copy-egs ark:train.egs ark:1.egs ark:2.egs\n"
308  "See also: nnet3-subset-egs, nnet3-get-egs, nnet3-merge-egs, nnet3-shuffle-egs\n";
309 
310  bool random = false;
311  int32 srand_seed = 0;
312  int32 frame_shift = 0;
313  BaseFloat keep_proportion = 1.0;
314 
315  // The following config variables, if set, can be used to extract a single
316  // frame of labels from a multi-frame example, and/or to reduce the amount
317  // of context.
318  int32 left_context = -1, right_context = -1;
319 
320  // you can set frame to a number to select a single frame with a particular
321  // offset, or to 'random' to select a random single frame.
322  std::string frame_str,
323  eg_weight_rspecifier, eg_output_name_rspecifier;
324 
325  ParseOptions po(usage);
326  po.Register("random", &random, "If true, will write frames to output "
327  "archives randomly, not round-robin.");
328  po.Register("frame-shift", &frame_shift, "Allows you to shift time values "
329  "in the supervision data (excluding iVector data). Only really "
330  "useful in clockwork topologies (i.e. any topology for which "
331  "modulus != 1). Shifting is done after any frame selection.");
332  po.Register("keep-proportion", &keep_proportion, "If <1.0, this program will "
333  "randomly keep this proportion of the input samples. If >1.0, it will "
334  "in expectation copy a sample this many times. It will copy it a number "
335  "of times equal to floor(keep-proportion) or ceil(keep-proportion).");
336  po.Register("srand", &srand_seed, "Seed for random number generator "
337  "(only relevant if --random=true or --keep-proportion != 1.0)");
338  po.Register("frame", &frame_str, "This option can be used to select a single "
339  "frame from each multi-frame example. Set to a number 0, 1, etc. "
340  "to select a frame with a given index, or 'random' to select a "
341  "random frame.");
342  po.Register("left-context", &left_context, "Can be used to truncate the "
343  "feature left-context that we output.");
344  po.Register("right-context", &right_context, "Can be used to truncate the "
345  "feature right-context that we output.");
346  po.Register("weights", &eg_weight_rspecifier,
347  "Rspecifier indexed by the key of egs, providing a weight by "
348  "which we will scale the supervision matrix for that eg. "
349  "Used in multilingual training.");
350  po.Register("outputs", &eg_output_name_rspecifier,
351  "Rspecifier indexed by the key of egs, providing a string-valued "
352  "output name, e.g. 'output-0'. If provided, the NnetIo with "
353  "name 'output' will be renamed to the provided name. Used in "
354  "multilingual training.");
355  po.Read(argc, argv);
356 
357  srand(srand_seed);
358 
359  if (po.NumArgs() < 2) {
360  po.PrintUsage();
361  exit(1);
362  }
363 
364  std::string examples_rspecifier = po.GetArg(1);
365 
366  SequentialNnetExampleReader example_reader(examples_rspecifier);
367 
368  // In the normal case, these would not be used. These are only applicable
369  // for multi-task or multilingual training.
370  RandomAccessTokenReader output_name_reader(eg_output_name_rspecifier);
371  RandomAccessBaseFloatReader egs_weight_reader(eg_weight_rspecifier);
372 
373  int32 num_outputs = po.NumArgs() - 1;
374  std::vector<NnetExampleWriter*> example_writers(num_outputs);
375  for (int32 i = 0; i < num_outputs; i++)
376  example_writers[i] = new NnetExampleWriter(po.GetArg(i+2));
377 
378 
379  int64 num_read = 0, num_written = 0, num_err = 0;
380  for (; !example_reader.Done(); example_reader.Next(), num_read++) {
381  const std::string &key = example_reader.Key();
382  NnetExample &eg = example_reader.Value();
383  // count is normally 1; could be 0, or possibly >1.
384  int32 count = GetCount(keep_proportion);
385 
386  if (!eg_weight_rspecifier.empty()) {
387  BaseFloat weight = 1.0;
388  if (!egs_weight_reader.HasKey(key)) {
389  KALDI_WARN << "No weight for example key " << key;
390  num_err++;
391  continue;
392  }
393  weight = egs_weight_reader.Value(key);
394  ScaleSupervisionWeight(weight, &eg);
395  }
396 
397  std::string new_output_name;
398  if (!eg_output_name_rspecifier.empty()) {
399  if (!output_name_reader.HasKey(key)) {
400  KALDI_WARN << "No new output-name for example key " << key;
401  num_err++;
402  continue;
403  }
404  new_output_name = output_name_reader.Value(key);
405  }
406  for (int32 c = 0; c < count; c++) {
407  int32 index = (random ? Rand() : num_written) % num_outputs;
408  if (frame_str == "" && left_context == -1 && right_context == -1 &&
409  frame_shift == 0) {
410  if (!new_output_name.empty() && c == 0)
411  RenameOutputs(new_output_name, &eg);
412  example_writers[index]->Write(key, eg);
413  num_written++;
414  } else { // the --frame option or context options were set.
415  NnetExample eg_modified;
416  if (SelectFromExample(eg, frame_str, left_context, right_context,
417  frame_shift, &eg_modified)) {
418  if (!new_output_name.empty())
419  RenameOutputs(new_output_name, &eg_modified);
420  // this branch of the if statement will almost always be taken (should only
421  // not be taken for shorter-than-normal egs from the end of a file.
422  example_writers[index]->Write(key, eg_modified);
423  num_written++;
424  }
425  }
426  }
427  }
428 
429  for (int32 i = 0; i < num_outputs; i++)
430  delete example_writers[i];
431  KALDI_LOG << "Read " << num_read << " neural-network training examples, wrote "
432  << num_written << ", "
433  << num_err << " examples had errors.";
434  return (num_written == 0 ? 1 : 0);
435  } catch(const std::exception &e) {
436  std::cerr << e.what() << '\n';
437  return -1;
438  }
439 }
NnetExample is the input data and corresponding label (or labels) for one or more frames of input...
Definition: nnet-example.h:111
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
bool ConvertStringToInteger(const std::string &str, Int *out)
Converts a string into an integer via strtoll and returns false if there was any kind of problem (i...
Definition: text-utils.h:118
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
bool WithProb(BaseFloat prob, struct RandomState *state)
Definition: kaldi-math.cc:72
kaldi::int32 int32
GeneralMatrix features
The features or labels.
Definition: nnet-example.h:46
void ShiftExampleTimes(int32 t_offset, const std::vector< std::string > &exclude_names, NnetExample *eg)
Shifts the time-index t of everything in the "eg" by adding "t_offset" to all "t" values...
std::vector< Index > indexes
"indexes" is a vector the same length as features.NumRows(), explaining the meaning of each row of th...
Definition: nnet-example.h:42
bool ContainsSingleExample(const NnetExample &eg, int32 *min_input_t, int32 *max_input_t, int32 *min_output_t, int32 *max_output_t)
Returns true if the "eg" contains just a single example, meaning that all the "n" values in the index...
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
void FilterGeneralMatrixRows(const GeneralMatrix &in, const std::vector< bool > &keep_rows, GeneralMatrix *out)
Outputs a GeneralMatrix containing only the rows r of "in" such that keep_rows[r] == true...
const size_t count
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
int main(int argc, char *argv[])
const T & Value(const std::string &key)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
void ScaleSupervisionWeight(BaseFloat weight, NnetExample *eg)
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
#define KALDI_WARN
Definition: kaldi-error.h:150
void RenameOutputs(const std::string &new_name, NnetExample *eg)
bool HasKey(const std::string &key)
int32 GetCount(double expected_count)
int Rand(struct RandomState *state)
Definition: kaldi-math.cc:45
int NumArgs() const
Number of positional parameters (c.f. argc-1).
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
TableWriter< KaldiObjectHolder< NnetExample > > NnetExampleWriter
Definition: nnet-example.h:170
std::string name
the name of the input in the neural net; in simple setups it will just be "input".
Definition: nnet-example.h:36
bool SelectFromExample(const NnetExample &eg, std::string frame_str, int32 left_context, int32 right_context, int32 frame_shift, NnetExample *eg_out)
This function is responsible for possibly selecting one frame from multiple supervised frames...
void FilterExample(const NnetExample &eg, int32 min_input_t, int32 max_input_t, int32 min_output_t, int32 max_output_t, NnetExample *eg_out)
This function filters the indexes (and associated feature rows) in a NnetExample, removing any index/...
std::vector< NnetIo > io
"io" contains the input and output.
Definition: nnet-example.h:116
#define KALDI_LOG
Definition: kaldi-error.h:153
int32 RandInt(int32 min_val, int32 max_val, struct RandomState *state)
Definition: kaldi-math.cc:95