nnet3-copy-egs.cc File Reference


Namespaces

 kaldi
 This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for mispronunciations detection tasks, the reference:
 
 kaldi::nnet3
 

Functions

void RenameOutputs (const std::string &new_name, NnetExample *eg)
 
void ScaleSupervisionWeight (BaseFloat weight, NnetExample *eg)
 
int32 GetCount (double expected_count)
 
bool ContainsSingleExample (const NnetExample &eg, int32 *min_input_t, int32 *max_input_t, int32 *min_output_t, int32 *max_output_t)
 Returns true if the "eg" contains just a single example, meaning that all the "n" values in the indexes are zero, and the example has NnetIo members named both "input" and "output".
 
void FilterExample (const NnetExample &eg, int32 min_input_t, int32 max_input_t, int32 min_output_t, int32 max_output_t, NnetExample *eg_out)
 This function filters the indexes (and associated feature rows) in a NnetExample, removing any index/row in an NnetIo named "input" with t < min_input_t or t > max_input_t and any index/row in an NnetIo named "output" with t < min_output_t or t > max_output_t.
 
bool SelectFromExample (const NnetExample &eg, std::string frame_str, int32 left_context, int32 right_context, int32 frame_shift, NnetExample *eg_out)
 This function is responsible for possibly selecting one frame from multiple supervised frames, and reducing the left and right context as specified.
 
int main (int argc, char *argv[])
 
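ContainsSingleExample(), FilterExample() and SelectFromExample() all work on the t values (time indexes) of a single, unmerged example. The following C++ sketch illustrates that window arithmetic under stated assumptions: `Index`, `Io`, `SingleExampleRanges`, `Filter` and `Select` are hypothetical, simplified stand-ins rather than Kaldi's classes or functions; the real NnetIo also stores feature rows that are filtered together with the indexes; this sketch simply returns false for a merged example; and the `--frame=random` and `--frame-shift` cases are not shown. The numeric frame is matched against output t values.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical, simplified stand-ins for Kaldi's Index and NnetIo; the real
// NnetIo also holds one feature row per index.
struct Index { int n, t, x; };
struct Io { std::string name; std::vector<Index> indexes; };

// ContainsSingleExample() contract: every n is 0 and both "input" and
// "output" exist. Also reports the t range of "input" and of "output".
bool SingleExampleRanges(const std::vector<Io> &eg, int *min_in, int *max_in,
                         int *min_out, int *max_out) {
  bool has_in = false, has_out = false;
  for (const Io &io : eg) {
    if (io.indexes.empty()) return false;
    int lo = io.indexes[0].t, hi = lo;
    for (const Index &ix : io.indexes) {
      if (ix.n != 0) return false;  // more than one example: already merged
      lo = std::min(lo, ix.t);
      hi = std::max(hi, ix.t);
    }
    if (io.name == "input") { has_in = true; *min_in = lo; *max_in = hi; }
    if (io.name == "output") { has_out = true; *min_out = lo; *max_out = hi; }
  }
  return has_in && has_out;
}

// FilterExample() rule: drop "input" indexes outside [min_in, max_in] and
// "output" indexes outside [min_out, max_out]; other NnetIos are kept as-is.
std::vector<Io> Filter(const std::vector<Io> &eg, int min_in, int max_in,
                       int min_out, int max_out) {
  std::vector<Io> out;
  for (const Io &io : eg) {
    Io kept{io.name, {}};
    for (const Index &ix : io.indexes) {
      bool drop =
          (io.name == "input" && (ix.t < min_in || ix.t > max_in)) ||
          (io.name == "output" && (ix.t < min_out || ix.t > max_out));
      if (!drop) kept.indexes.push_back(ix);
    }
    out.push_back(kept);
  }
  return out;
}

// SelectFromExample()-style selection: optionally keep a single output frame
// (matched against t), then clip the input window to left/right context
// around the output frames; -1 means "do not change that side".
bool Select(const std::vector<Io> &eg, bool select_frame, int frame,
            int left_context, int right_context, std::vector<Io> *result) {
  assert(left_context >= -1 && right_context >= -1);
  int min_in, max_in, min_out, max_out;
  if (!SingleExampleRanges(eg, &min_in, &max_in, &min_out, &max_out))
    return false;
  if (select_frame) {
    if (frame < min_out || frame > max_out) return false;  // out of range
    min_out = max_out = frame;
  }
  if (left_context != -1) min_in = std::max(min_in, min_out - left_context);
  if (right_context != -1) max_in = std::min(max_in, max_out + right_context);
  *result = Filter(eg, min_in, max_in, min_out, max_out);
  return true;
}
```

For an example whose input covers t = -3..5 and whose output covers t = 0..2, selecting frame 1 with left context 2 and right context 1 keeps input t = -1..2 and output t = 1.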

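GetCount() is listed without a description, but the `--keep-proportion` help string in main() states its contract: each example is copied floor(keep-proportion) or ceil(keep-proportion) times, so that the expected number of copies equals keep-proportion. A minimal sketch of that rule, assuming a `std::mt19937` generator and `std::bernoulli_distribution` (the tool itself seeds its random numbers via `--srand`); `GetCountSketch` is a hypothetical name, not Kaldi's implementation:

```cpp
#include <cassert>
#include <cmath>
#include <random>

// Returns floor(expected_count) or floor(expected_count) + 1, choosing the
// larger value with probability equal to the fractional part, so the mean
// of the returned counts equals expected_count.
int GetCountSketch(double expected_count, std::mt19937 &rng) {
  assert(expected_count >= 0.0);
  int ans = static_cast<int>(std::floor(expected_count));
  double frac = expected_count - ans;  // probability of one extra copy
  std::bernoulli_distribution extra(frac);
  if (extra(rng)) ans++;
  return ans;
}
```

For example, keep-proportion 2.3 yields 2 copies with probability 0.7 and 3 copies with probability 0.3, i.e. 2.3 copies on average; keep-proportion 0.3 drops about 70% of the examples.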
Function Documentation

main()

int main(int argc, char *argv[])

Definition at line 290 of file nnet3-copy-egs.cc.

References SequentialTableReader< Holder >::Done(), ParseOptions::GetArg(), kaldi::nnet3::GetCount(), RandomAccessTableReader< Holder >::HasKey(), KALDI_LOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), ParseOptions::PrintUsage(), kaldi::Rand(), ParseOptions::Read(), ParseOptions::Register(), kaldi::nnet3::RenameOutputs(), kaldi::nnet3::ScaleSupervisionWeight(), kaldi::nnet3::SelectFromExample(), RandomAccessTableReader< Holder >::Value(), and SequentialTableReader< Holder >::Value().

{
  try {
    using namespace kaldi;
    using namespace kaldi::nnet3;
    typedef kaldi::int32 int32;
    typedef kaldi::int64 int64;

    const char *usage =
        "Copy examples (single frames or fixed-size groups of frames) for neural\n"
        "network training, possibly changing the binary mode. Supports multiple wspecifiers, in\n"
        "which case it will write the examples round-robin to the outputs.\n"
        "\n"
        "Usage: nnet3-copy-egs [options] <egs-rspecifier> <egs-wspecifier1> [<egs-wspecifier2> ...]\n"
        "\n"
        "e.g.\n"
        "nnet3-copy-egs ark:train.egs ark,t:text.egs\n"
        "or:\n"
        "nnet3-copy-egs ark:train.egs ark:1.egs ark:2.egs\n"
        "See also: nnet3-subset-egs, nnet3-get-egs, nnet3-merge-egs, nnet3-shuffle-egs\n";

    bool random = false;
    int32 srand_seed = 0;
    int32 frame_shift = 0;
    BaseFloat keep_proportion = 1.0;

    // The following config variables, if set, can be used to extract a single
    // frame of labels from a multi-frame example, and/or to reduce the amount
    // of context.
    int32 left_context = -1, right_context = -1;

    // You can set frame to a number to select a single frame with a particular
    // offset, or to 'random' to select a random single frame.
    std::string frame_str,
        eg_weight_rspecifier, eg_output_name_rspecifier;

    ParseOptions po(usage);
    po.Register("random", &random, "If true, will write frames to output "
                "archives randomly, not round-robin.");
    po.Register("frame-shift", &frame_shift, "Allows you to shift time values "
                "in the supervision data (excluding iVector data). Only really "
                "useful in clockwork topologies (i.e. any topology for which "
                "modulus != 1). Shifting is done after any frame selection.");
    po.Register("keep-proportion", &keep_proportion, "If <1.0, this program will "
                "randomly keep this proportion of the input samples. If >1.0, it will "
                "in expectation copy a sample this many times. It will copy it a number "
                "of times equal to floor(keep-proportion) or ceil(keep-proportion).");
    po.Register("srand", &srand_seed, "Seed for random number generator "
                "(only relevant if --random=true or --keep-proportion != 1.0)");
    po.Register("frame", &frame_str, "This option can be used to select a single "
                "frame from each multi-frame example. Set to a number 0, 1, etc. "
                "to select a frame with a given index, or 'random' to select a "
                "random frame.");
    po.Register("left-context", &left_context, "Can be used to truncate the "
                "feature left-context that we output.");
    po.Register("right-context", &right_context, "Can be used to truncate the "
                "feature right-context that we output.");
    po.Register("weights", &eg_weight_rspecifier,
                "Rspecifier indexed by the key of egs, providing a weight by "
                "which we will scale the supervision matrix for that eg. "
                "Used in multilingual training.");
    po.Register("outputs", &eg_output_name_rspecifier,
                "Rspecifier indexed by the key of egs, providing a string-valued "
                "output name, e.g. 'output-0'. If provided, the NnetIo with "
                "name 'output' will be renamed to the provided name. Used in "
                "multilingual training.");
    po.Read(argc, argv);

    srand(srand_seed);

    if (po.NumArgs() < 2) {
      po.PrintUsage();
      exit(1);
    }

    std::string examples_rspecifier = po.GetArg(1);

    SequentialNnetExampleReader example_reader(examples_rspecifier);

    // In the normal case, these would not be used. These are only applicable
    // for multi-task or multilingual training.
    RandomAccessTokenReader output_name_reader(eg_output_name_rspecifier);
    RandomAccessBaseFloatReader egs_weight_reader(eg_weight_rspecifier);

    int32 num_outputs = po.NumArgs() - 1;
    std::vector<NnetExampleWriter*> example_writers(num_outputs);
    for (int32 i = 0; i < num_outputs; i++)
      example_writers[i] = new NnetExampleWriter(po.GetArg(i+2));

    int64 num_read = 0, num_written = 0, num_err = 0;
    for (; !example_reader.Done(); example_reader.Next(), num_read++) {
      const std::string &key = example_reader.Key();
      NnetExample &eg = example_reader.Value();
      // count is normally 1; could be 0, or possibly >1.
      int32 count = GetCount(keep_proportion);

      if (!eg_weight_rspecifier.empty()) {
        BaseFloat weight = 1.0;
        if (!egs_weight_reader.HasKey(key)) {
          KALDI_WARN << "No weight for example key " << key;
          num_err++;
          continue;
        }
        weight = egs_weight_reader.Value(key);
        ScaleSupervisionWeight(weight, &eg);
      }

      std::string new_output_name;
      if (!eg_output_name_rspecifier.empty()) {
        if (!output_name_reader.HasKey(key)) {
          KALDI_WARN << "No new output-name for example key " << key;
          num_err++;
          continue;
        }
        new_output_name = output_name_reader.Value(key);
      }
      for (int32 c = 0; c < count; c++) {
        int32 index = (random ? Rand() : num_written) % num_outputs;
        if (frame_str == "" && left_context == -1 && right_context == -1 &&
            frame_shift == 0) {
          if (!new_output_name.empty() && c == 0)
            RenameOutputs(new_output_name, &eg);
          example_writers[index]->Write(key, eg);
          num_written++;
        } else {  // --frame, --left-context, --right-context or --frame-shift was set.
          NnetExample eg_modified;
          if (SelectFromExample(eg, frame_str, left_context, right_context,
                                frame_shift, &eg_modified)) {
            if (!new_output_name.empty())
              RenameOutputs(new_output_name, &eg_modified);
            // This branch of the if statement will almost always be taken (it
            // should only be skipped for shorter-than-normal egs from the end
            // of a file).
            example_writers[index]->Write(key, eg_modified);
            num_written++;
          }
        }
      }
    }

    for (int32 i = 0; i < num_outputs; i++)
      delete example_writers[i];
    KALDI_LOG << "Read " << num_read << " neural-network training examples, wrote "
              << num_written << ", "
              << num_err << " examples had errors.";
    return (num_written == 0 ? 1 : 0);
  } catch(const std::exception &e) {
    std::cerr << e.what() << '\n';
    return -1;
  }
}
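In main(), each example is written `count` times, and each copy is routed to archive `num_written % num_outputs` (round-robin) or to a random archive when `--random=true`. The sketch below isolates that routing under stated assumptions: `DistributeCopies` is a hypothetical helper, it uses `std::uniform_int_distribution` in place of `Rand() % num_outputs`, and it ignores copies dropped by SelectFromExample() (which, in main(), do not advance `num_written`).

```cpp
#include <cassert>
#include <cstdint>
#include <random>
#include <vector>

// Hypothetical helper: counts[i] is the number of copies of example i
// (as returned by GetCount()); returns how many copies each output received.
std::vector<int64_t> DistributeCopies(const std::vector<int> &counts,
                                      int num_outputs, bool random,
                                      std::mt19937 &rng) {
  assert(num_outputs > 0);
  std::vector<int64_t> per_output(num_outputs, 0);
  std::uniform_int_distribution<int> pick(0, num_outputs - 1);
  int64_t num_written = 0;
  for (int count : counts) {
    for (int c = 0; c < count; c++) {
      // Round-robin uses the running count of written copies, as main() does.
      int index = random ? pick(rng)
                         : static_cast<int>(num_written % num_outputs);
      per_output[index]++;
      num_written++;
    }
  }
  return per_output;
}
```

Because the round-robin index comes from the running `num_written` counter rather than from the example's position in the input, an example with count 0 leaves no gap: with two outputs and counts {2, 0, 1}, the three copies go to outputs 0, 1, 0.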
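The `--weights` and `--outputs` options, used in multilingual training, call ScaleSupervisionWeight() and RenameOutputs(). Based only on the option help strings (scale the supervision matrix; rename the NnetIo named "output"), here is a hedged sketch. `IoSketch`, `ScaleSupervisionSketch` and `RenameOutputsSketch` are hypothetical; the feature matrix is a plain nested vector rather than Kaldi's matrix type, and returning false on a missing "output" is this sketch's choice.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for an NnetIo: a name plus a dense
// feature matrix (one row per frame). Not Kaldi's real class.
struct IoSketch {
  std::string name;
  std::vector<std::vector<float>> features;
};

// --weights behaviour: scale the supervision held by the NnetIo named
// "output". Returns false if no such NnetIo exists.
bool ScaleSupervisionSketch(float weight, std::vector<IoSketch> *eg) {
  bool found = false;
  for (IoSketch &io : *eg) {
    if (io.name != "output") continue;
    for (std::vector<float> &row : io.features)
      for (float &v : row) v *= weight;
    found = true;
  }
  return found;
}

// --outputs behaviour: rename the NnetIo called "output" (e.g. to
// "output-0"). Returns false if no such NnetIo exists.
bool RenameOutputsSketch(const std::string &new_name,
                         std::vector<IoSketch> *eg) {
  for (IoSketch &io : *eg) {
    if (io.name == "output") {
      io.name = new_name;
      return true;
    }
  }
  return false;
}
```

In main(), scaling is applied before renaming, so the lookup of "output" still succeeds. On the unmodified path, RenameOutputs() is called only for the first copy (`c == 0`), because the same `eg` object is reused for later copies and no longer contains an NnetIo named "output".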