nnet-combine-egs-discriminative.cc
Go to the documentation of this file.
1 // nnet2bin/nnet-combine-egs-discriminative.cc
2 
3 // Copyright 2012-2013 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "hmm/transition-model.h"
#include "nnet2/nnet-example-functions.h"
24 
25 int main(int argc, char *argv[]) {
26  try {
27  using namespace kaldi;
28  using namespace kaldi::nnet2;
29  typedef kaldi::int32 int32;
30  typedef kaldi::int64 int64;
31 
32  const char *usage =
33  "Copy examples for discriminative neural network training,\n"
34  "and combine successive examples if their combined length will\n"
35  "be less than --max-length. This can help to improve efficiency\n"
36  "(--max-length corresponds to minibatch size)\n"
37  "\n"
38  "Usage: nnet-combine-egs-discriminative [options] <egs-rspecifier> <egs-wspecifier>\n"
39  "\n"
40  "e.g.\n"
41  "nnet-combine-egs-discriminative --max-length=512 ark:temp.1.degs ark:1.degs\n";
42 
43  int32 max_length = 512;
44  int32 hard_max_length = 2048;
45  int32 batch_size = 250;
46  ParseOptions po(usage);
47  po.Register("max-length", &max_length, "Maximum length of example that we "
48  "will create when combining");
49  po.Register("batch-size", &batch_size, "Size of batch used when combinging "
50  "examples");
51  po.Register("hard-max-length", &hard_max_length, "Length of example beyond "
52  "which we will discard (very long examples may cause out of "
53  "memory errors)");
54 
55  po.Read(argc, argv);
56 
57  if (po.NumArgs() != 2) {
58  po.PrintUsage();
59  exit(1);
60  }
61 
62  KALDI_ASSERT(hard_max_length >= max_length);
63  KALDI_ASSERT(batch_size >= 1);
64 
65  std::string examples_rspecifier = po.GetArg(1),
66  examples_wspecifier = po.GetArg(2);
67 
69  examples_rspecifier);
70  DiscriminativeNnetExampleWriter example_writer(
71  examples_wspecifier);
72 
73  int64 num_read = 0, num_written = 0, num_discarded = 0;
74 
75  while (!example_reader.Done()) {
76  std::vector<DiscriminativeNnetExample> buffer;
77  size_t size = batch_size;
78  buffer.reserve(size);
79 
80  for (; !example_reader.Done() && buffer.size() < size;
81  example_reader.Next()) {
82  buffer.push_back(example_reader.Value());
83  num_read++;
84  }
85 
86  std::vector<DiscriminativeNnetExample> combined;
87  CombineDiscriminativeExamples(max_length, buffer, &combined);
88  buffer.clear();
89  for (size_t i = 0; i < combined.size(); i++) {
90  const DiscriminativeNnetExample &eg = combined[i];
91  int32 num_frames = eg.input_frames.NumRows();
92  if (num_frames > hard_max_length) {
93  KALDI_WARN << "Discarding segment of length " << num_frames
94  << " because it exceeds --hard-max-length="
95  << hard_max_length;
96  num_discarded++;
97  } else {
98  std::ostringstream ostr;
99  ostr << (num_written++);
100  example_writer.Write(ostr.str(), eg);
101  }
102  }
103  }
104 
105  KALDI_LOG << "Read " << num_read << " discriminative neural-network training"
106  << " examples, wrote " << num_written << ", discarded "
107  << num_discarded;
108  return (num_written == 0 ? 1 : 0);
109  } catch(const std::exception &e) {
110  std::cerr << e.what() << '\n';
111  return -1;
112  }
113 }
114 
115 
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
int main(int argc, char *argv[])
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
void CombineDiscriminativeExamples(int32 max_length, const std::vector< DiscriminativeNnetExample > &input, std::vector< DiscriminativeNnetExample > *output)
This function is used to combine multiple discriminative-training examples (each corresponding to a s...
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
Matrix< BaseFloat > input_frames
The input data, typically with a number of frames [NumRows()] larger than labels.size(), because it includes features to the left and right as needed for the temporal context of the network.
Definition: nnet-example.h:159
int NumArgs() const
Number of positional parameters (c.f. argc-1).
This struct is used to store the information we need for discriminative training (MMI or MPE)...
Definition: nnet-example.h:136
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
#define KALDI_LOG
Definition: kaldi-error.h:153
Note on how to parse this filename: it contains functions related to neural-net training examples...