nnet-combine-egs-discriminative.cc File Reference
Include dependency graph for nnet-combine-egs-discriminative.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 25 of file nnet-combine-egs-discriminative.cc.

References kaldi::nnet2::CombineDiscriminativeExamples(), SequentialTableReader< Holder >::Done(), ParseOptions::GetArg(), rnnlm::i, DiscriminativeNnetExample::input_frames, KALDI_ASSERT, KALDI_LOG, KALDI_WARN, SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), MatrixBase< Real >::NumRows(), ParseOptions::PrintUsage(), ParseOptions::Read(), ParseOptions::Register(), SequentialTableReader< Holder >::Value(), and TableWriter< Holder >::Write().

25  {
26  try {
27  using namespace kaldi;
28  using namespace kaldi::nnet2;
29  typedef kaldi::int32 int32;
30  typedef kaldi::int64 int64;
31 
32  const char *usage =
33  "Copy examples for discriminative neural network training,\n"
34  "and combine successive examples if their combined length will\n"
35  "be less than --max-length. This can help to improve efficiency\n"
36  "(--max-length corresponds to minibatch size)\n"
37  "\n"
38  "Usage: nnet-combine-egs-discriminative [options] <egs-rspecifier> <egs-wspecifier>\n"
39  "\n"
40  "e.g.\n"
41  "nnet-combine-egs-discriminative --max-length=512 ark:temp.1.degs ark:1.degs\n";
42 
43  int32 max_length = 512;
44  int32 hard_max_length = 2048;
45  int32 batch_size = 250;
46  ParseOptions po(usage);
47  po.Register("max-length", &max_length, "Maximum length of example that we "
48  "will create when combining");
49  po.Register("batch-size", &batch_size, "Size of batch used when combinging "
50  "examples");
51  po.Register("hard-max-length", &hard_max_length, "Length of example beyond "
52  "which we will discard (very long examples may cause out of "
53  "memory errors)");
54 
55  po.Read(argc, argv);
56 
57  if (po.NumArgs() != 2) {
58  po.PrintUsage();
59  exit(1);
60  }
61 
62  KALDI_ASSERT(hard_max_length >= max_length);
63  KALDI_ASSERT(batch_size >= 1);
64 
65  std::string examples_rspecifier = po.GetArg(1),
66  examples_wspecifier = po.GetArg(2);
67 
68  SequentialDiscriminativeNnetExampleReader example_reader(
69  examples_rspecifier);
70  DiscriminativeNnetExampleWriter example_writer(
71  examples_wspecifier);
72 
73  int64 num_read = 0, num_written = 0, num_discarded = 0;
74 
75  while (!example_reader.Done()) {
76  std::vector<DiscriminativeNnetExample> buffer;
77  size_t size = batch_size;
78  buffer.reserve(size);
79 
80  for (; !example_reader.Done() && buffer.size() < size;
81  example_reader.Next()) {
82  buffer.push_back(example_reader.Value());
83  num_read++;
84  }
85 
86  std::vector<DiscriminativeNnetExample> combined;
87  CombineDiscriminativeExamples(max_length, buffer, &combined);
88  buffer.clear();
89  for (size_t i = 0; i < combined.size(); i++) {
90  const DiscriminativeNnetExample &eg = combined[i];
91  int32 num_frames = eg.input_frames.NumRows();
92  if (num_frames > hard_max_length) {
93  KALDI_WARN << "Discarding segment of length " << num_frames
94  << " because it exceeds --hard-max-length="
95  << hard_max_length;
96  num_discarded++;
97  } else {
98  std::ostringstream ostr;
99  ostr << (num_written++);
100  example_writer.Write(ostr.str(), eg);
101  }
102  }
103  }
104 
105  KALDI_LOG << "Read " << num_read << " discriminative neural-network training"
106  << " examples, wrote " << num_written << ", discarded "
107  << num_discarded;
108  return (num_written == 0 ? 1 : 0);
109  } catch(const std::exception &e) {
110  std::cerr << e.what() << '\n';
111  return -1;
112  }
113 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
void CombineDiscriminativeExamples(int32 max_length, const std::vector< DiscriminativeNnetExample > &input, std::vector< DiscriminativeNnetExample > *output)
This function is used to combine multiple discriminative-training examples (each corresponding to a s...
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
#define KALDI_WARN
Definition: kaldi-error.h:150
Matrix< BaseFloat > input_frames
The input data — typically with a number of frames [NumRows()] larger than labels.size(), because it includes features to the left and right as needed for the temporal context of the network.
Definition: nnet-example.h:159
This struct is used to store the information we need for discriminative training (MMI or MPE)...
Definition: nnet-example.h:136
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
#define KALDI_LOG
Definition: kaldi-error.h:153