nnet-copy-egs.cc
Go to the documentation of this file.
1 // nnet2bin/nnet-copy-egs.cc
2 
3 // Copyright 2012 Johns Hopkins University (author: Daniel Povey)
4 // Copyright 2014 Vimal Manohar
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
#include <cmath>
#include <string>
#include <vector>

#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "hmm/transition-model.h"
#include "nnet2/nnet-example-functions.h"
25 
26 namespace kaldi {
27 namespace nnet2 {
28 // returns an integer randomly drawn with expected value "expected_count"
29 // (will be either floor(expected_count) or ceil(expected_count)).
30 // this will go into an infinite loop if expected_count is very huge, but
31 // it should never be that huge.
32 int32 GetCount(double expected_count) {
33  KALDI_ASSERT(expected_count >= 0.0);
34  int32 ans = 0;
35  while (expected_count > 1.0) {
36  ans++;
37  expected_count--;
38  }
39  if (WithProb(expected_count))
40  ans++;
41  return ans;
42 }
43 
44 } // namespace nnet2
45 } // namespace kaldi
46 
47 int main(int argc, char *argv[]) {
48  try {
49  using namespace kaldi;
50  using namespace kaldi::nnet2;
51  typedef kaldi::int32 int32;
52  typedef kaldi::int64 int64;
53 
54  const char *usage =
55  "Copy examples (typically single frames) for neural network training,\n"
56  "possibly changing the binary mode. Supports multiple wspecifiers, in\n"
57  "which case it will write the examples round-robin to the outputs.\n"
58  "\n"
59  "Usage: nnet-copy-egs [options] <egs-rspecifier> <egs-wspecifier1> [<egs-wspecifier2> ...]\n"
60  "\n"
61  "e.g.\n"
62  "nnet-copy-egs ark:train.egs ark,t:text.egs\n"
63  "or:\n"
64  "nnet-copy-egs ark:train.egs ark:1.egs ark:2.egs\n";
65 
66  bool random = false;
67  int32 srand_seed = 0;
68  BaseFloat keep_proportion = 1.0;
69 
70  // The following config variables, if set, can be used to extract a single
71  // frame of labels from a multi-frame example, and/or to reduce the amount
72  // of context.
73  int32 left_context = -1, right_context = -1;
74  // you can set frame to a number to select a single frame with a particular
75  // offset, or to 'random' to select a random single frame.
76  std::string frame_str;
77 
78  ParseOptions po(usage);
79  po.Register("random", &random, "If true, will write frames to output "
80  "archives randomly, not round-robin.");
81  po.Register("keep-proportion", &keep_proportion, "If <1.0, this program will "
82  "randomly keep this proportion of the input samples. If >1.0, it will "
83  "in expectation copy a sample this many times. It will copy it a number "
84  "of times equal to floor(keep-proportion) or ceil(keep-proportion).");
85  po.Register("srand", &srand_seed, "Seed for random number generator "
86  "(only relevant if --random=true or --keep-proportion != 1.0)");
87  po.Register("frame", &frame_str, "This option can be used to select a single "
88  "frame from each multi-frame example. Set to a number 0, 1, etc. "
89  "to select a frame with a given index, or 'random' to select a "
90  "random frame.");
91  po.Register("left-context", &left_context, "Can be used to truncate the "
92  "feature left-context that we output.");
93  po.Register("right-context", &right_context, "Can be used to truncate the "
94  "feature right-context that we output.");
95 
96 
97  po.Read(argc, argv);
98 
99  srand(srand_seed);
100 
101  int32 frame = -1; // -1 means don't do any selection (--frame option unse),
102  // --2 means random selection.
103  if (frame_str != "") {
104  if (!ConvertStringToInteger(frame_str, &frame)) {
105  if (frame_str == "random") {
106  frame = -2;
107  } else {
108  KALDI_ERR << "Invalid --frame option: '" << frame_str << "'";
109  }
110  } else {
111  KALDI_ASSERT(frame >= 0);
112  }
113  }
114  // the following derived variables will be used if the frame, left_context,
115  // or right_context options were set (the frame option will be more common).
116  bool copy_eg = (frame != -1 || left_context != -1 || right_context != -1);
117  int32 start_frame = -1, num_frames = -1;
118  if (frame != -1) { // frame >= 0 or frame == -2 meaning random frame
119  num_frames = 1;
120  start_frame = frame; // value will be ignored if frame == -2.
121  }
122 
123  if (po.NumArgs() < 2) {
124  po.PrintUsage();
125  exit(1);
126  }
127 
128  std::string examples_rspecifier = po.GetArg(1);
129 
130  SequentialNnetExampleReader example_reader(examples_rspecifier);
131 
132  int32 num_outputs = po.NumArgs() - 1;
133  std::vector<NnetExampleWriter*> example_writers(num_outputs);
134  for (int32 i = 0; i < num_outputs; i++)
135  example_writers[i] = new NnetExampleWriter(po.GetArg(i+2));
136 
137 
138  int64 num_read = 0, num_written = 0;
139  for (; !example_reader.Done(); example_reader.Next(), num_read++) {
140  // count is normally 1; could be 0, or possibly >1.
141  int32 count = GetCount(keep_proportion);
142  std::string key = example_reader.Key();
143  const NnetExample &eg = example_reader.Value();
144  for (int32 c = 0; c < count; c++) {
145  int32 index = (random ? Rand() : num_written) % num_outputs;
146  if (!copy_eg) {
147  example_writers[index]->Write(key, eg);
148  num_written++;
149  } else { // the --frame option or related options were set.
150  if (frame == -2) // --frame=random was set -> choose random frame
151  start_frame = RandInt(0, eg.labels.size() - 1);
152  if (start_frame == -1 || start_frame < eg.labels.size()) {
153  // note: we'd only reach here with start_frame == -1 if the
154  // --left-context or --right-context options were set (reducing
155  // context). -1 means use whatever we had in the original eg.
156  NnetExample eg_mod(eg, start_frame, num_frames,
157  left_context, right_context);
158  example_writers[index]->Write(key, eg_mod);
159  num_written++;
160  }
161  // else this frame was out of range for this eg; we don't make this an
162  // error, because it can happen for truncated multi-frame egs that
163  // were created at the end of an utterance.
164  }
165  }
166  }
167 
168  for (int32 i = 0; i < num_outputs; i++)
169  delete example_writers[i];
170  KALDI_LOG << "Read " << num_read << " neural-network training examples, wrote "
171  << num_written;
172  return (num_written == 0 ? 1 : 0);
173  } catch(const std::exception &e) {
174  std::cerr << e.what() << '\n';
175  return -1;
176  }
177 }
178 
179 
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
bool ConvertStringToInteger(const std::string &str, Int *out)
Converts a string into an integer via strtoll and returns false if there was any kind of problem (i...
Definition: text-utils.h:118
NnetExample is the input data and corresponding label (or labels) for one or more frames of input...
Definition: nnet-example.h:36
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
bool WithProb(BaseFloat prob, struct RandomState *state)
Definition: kaldi-math.cc:72
kaldi::int32 int32
void Register(const std::string &name, bool *ptr, const std::string &doc)
const size_t count
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
TableWriter< KaldiObjectHolder< NnetExample > > NnetExampleWriter
Definition: nnet-example.h:92
int Rand(struct RandomState *state)
Definition: kaldi-math.cc:45
int NumArgs() const
Number of positional parameters (c.f. argc-1).
int main(int argc, char *argv[])
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
std::vector< std::vector< std::pair< int32, BaseFloat > > > labels
The label(s) for each frame in a sequence of frames; in the normal case, this will be just [ [ (pdf-i...
Definition: nnet-example.h:43
#define KALDI_LOG
Definition: kaldi-error.h:153
int32 GetCount(double expected_count)
Note on how to parse this filename: it contains functions related to neural-net training examples...
int32 RandInt(int32 min_val, int32 max_val, struct RandomState *state)
Definition: kaldi-math.cc:95