nnet-example-utils.h
Go to the documentation of this file.
1 // nnet3/nnet-example-utils.h
2 
3 // Copyright 2015 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #ifndef KALDI_NNET3_NNET_EXAMPLE_UTILS_H_
21 #define KALDI_NNET3_NNET_EXAMPLE_UTILS_H_
22 
23 #include "nnet3/nnet-example.h"
24 #include "nnet3/nnet-computation.h"
25 #include "nnet3/nnet-compute.h"
26 #include "util/kaldi-table.h"
27 
28 namespace kaldi {
29 namespace nnet3 {
30 
31 
32 
// Merges the set of input examples in 'src' into a single example, which is
// written to '*dest' (typically the size of 'src' will be the minibatch size).
// NOTE(review): 'compress' presumably requests compression of the merged
// example, like the --compress option of ExampleMergingConfig -- confirm
// against the definition.
37 void MergeExamples(const std::vector<NnetExample> &src,
38  bool compress,
39  NnetExample *dest);
40 
41 
// Shifts the time-index t of everything in 'eg' by adding 't_offset' to all
// 't' values.
// NOTE(review): 'exclude_names' presumably names the inputs/outputs that are
// to be left unshifted -- confirm against the definition.
49 void ShiftExampleTimes(int32 t_offset,
50  const std::vector<std::string> &exclude_names,
51  NnetExample *eg);
52 
// Takes a NnetExample 'eg' (which should already have been frame-selected, if
// desired) together with the network 'nnet', and, judging by the signature,
// outputs the matching request to '*computation_request'.
// NOTE(review): 'need_model_derivative' and 'store_component_stats' are
// presumably passed through to the request -- confirm against the definition.
60 void GetComputationRequest(const Nnet &nnet,
61  const NnetExample &eg,
62  bool need_model_derivative,
63  bool store_component_stats,
64  ComputationRequest *computation_request);
65 
66 
67 // Writes as unsigned char a vector 'vec' that is required to have
68 // values between 0 and 1.
//   os:     the stream that the data is written to.
//   binary: NOTE(review): presumably selects binary versus text mode --
//           confirm against the definition.
//   vec:    the vector to write; its values must be between 0 and 1.
69 void WriteVectorAsChar(std::ostream &os,
70  bool binary,
71  const VectorBase<BaseFloat> &vec);
72 
73 // Reads data written by WriteVectorAsChar.
//   is:     the stream that the data is read from.
//   binary: NOTE(review): presumably must match the mode used when writing --
//           confirm against the definition.
//   vec:    output; receives the vector that was read.
74 void ReadVectorAsChar(std::istream &is,
75  bool binary,
76  Vector<BaseFloat> *vec);
77 
78 
79 // Warning: after reading in the values from the command line
80 // (Register() and then po.Read()), you should then call ComputeDerived()
81 // to set up the 'derived values' (parses 'num_frames_str').
89  std::string num_frames_str;
90 
91 
92  // The following parameters are derived parameters, computed by
93  // ComputeDerived().
94 
95  // the first element of the 'num_frames' vector is the 'principal' number of
96  // frames; the remaining elements are alternatives to the principal number of
97  // frames, to be used at most once or twice per file.
98  std::vector<int32> num_frames;
99 
101  left_context(0), right_context(0),
102  left_context_initial(-1), right_context_final(-1),
103  num_frames_overlap(0), frame_subsampling_factor(1),
104  num_frames_str("1") { }
105 
108  void ComputeDerived();
109 
// Registers this config's command-line options with 'po': --left-context,
// --right-context, --left-context-initial, --right-context-final,
// --num-frames, --num-frames-overlap and --frame-subsampling-factor. Each
// option is bound to a member variable (--num-frames is bound to the string
// 'num_frames_str'), and the string literals below are the user-visible help
// text for the options.
110  void Register(OptionsItf *po) {
111  po->Register("left-context", &left_context, "Number of frames of left "
112  "context of input features that are added to each "
113  "example");
114  po->Register("right-context", &right_context, "Number of frames of right "
115  "context of input features that are added to each "
116  "example");
117  po->Register("left-context-initial", &left_context_initial, "Number of "
118  "frames of left context of input features that are added to "
119  "each example at the start of the utterance (if <0, this "
120  "defaults to the same as --left-context)");
121  po->Register("right-context-final", &right_context_final, "Number of "
122  "frames of right context of input features that are added "
123  "to each example at the end of the utterance (if <0, this "
124  "defaults to the same as --right-context)");
125  po->Register("num-frames", &num_frames_str, "Number of frames with labels "
126  "that each example contains (i.e. the left and right context "
127  "are to be added to this). May just be an integer (e.g. "
128  "--num-frames=8), or a principal value followed by "
129  "alternative values to be used at most once for each utterance "
130  "to deal with odd-sized input, e.g. --num-frames=40,25,50 means "
131  "that most of the time the number of frames will be 40, but to "
132  "deal with odd-sized inputs we may also generate egs with these "
133  "other sizes. All these values will be rounded up to the "
134  "closest multiple of --frame-subsampling-factor. As a special case, "
135  "--num-frames=-1 means 'don't do any splitting'.");
136  po->Register("num-frames-overlap", &num_frames_overlap, "Number of frames of "
137  "overlap between adjacent examples (applies to chunks of size "
138  "equal to the primary [first-listed] --num-frames value... "
139  "will be adjusted for different-sized chunks). Advisory; "
140  "will not be exactly enforced.");
141  po->Register("frame-subsampling-factor", &frame_subsampling_factor, "Used "
142  "if the frame-rate of the output labels in the generated "
143  "examples will be less than the frame-rate at the input");
144  }
145 };
146 
147 
148 
158  // The 'output_weights' member is a vector of length equal to the
159  // num_frames divided by frame_subsampling_factor from the config.
160  // It contains values 0 < x <= 1 that represent weightings of
161  // output-frames. The idea is that if (because of overlaps) a
162  // frame appears in multiple chunks, we want to downweight it
163  // so that the total weight remains 1. (Of course, the calling
164  // code is free to ignore these weights if desired).
165  std::vector<BaseFloat> output_weights;
166 };
167 
168 
170  public:
171 
173 
174 
// Read-only accessor: returns a const reference to the config_ member.
175  const ExampleGenerationConfig& Config() const { return config_; }
176 
177  // Given an utterance length, this function creates for you a list of chunks
178  // into which to split the utterance. Note: this is partly random (will call
179  // srand()).
180  // Accumulates some stats which will be printed out in the destructor.
181  void GetChunksForUtterance(int32 utterance_length,
182  std::vector<ChunkTimeInfo> *chunk_info);
183 
184 
185  // This function returns true if 'supervision_length' (e.g. the length of the
186  // posterior, lattice or alignment) is what we expect given
187  // config_.frame_subsampling_factor. If not, it prints a warning (which is
188  // why the function needs 'utt'), and returns false. Note: we round up, so
189  // writing config_.frame_subsampling_factor as sf, we expect
190  // supervision_length = (utterance_length + sf - 1) / sf.
191  bool LengthsMatch(const std::string &utt,
192  int32 utterance_length,
193  int32 supervision_length,
194  int32 length_tolerance = 0) const;
195 
197 
// Returns an exit status derived from the accumulated stats: 0 if
// total_frames_in_chunks_ is greater than zero, and 1 otherwise.
198  int32 ExitStatus() { return (total_frames_in_chunks_ > 0 ? 0 : 1); }
199 
200  private:
201 
202 
203  void InitSplitForLength();
204 
205  // This function returns the 'default duration' in frames of a split, which if
206  // config_.num_frames_overlap is zero is just the sum of chunk sizes in the
207  // split (i.e. the sum of the vector's elements), but otherwise, we subtract
208  // the recommended overlap (see code for details).
209  float DefaultDurationOfSplit(const std::vector<int32> &split) const;
210 
211 
212  // Used in InitSplitForLength(), returns the maximum utterance-length considered
213  // separately in splits_for_length_. [above this, we'll assume that the additional
214  // length is consumed by multiples of the 'principal' chunk size.] It returns
215  // the primary chunk-size (config_.num_frames[0]) plus twice the largest of
216  // any of the allowed chunk sizes (i.e. the max of config_.num_frames)
217  int32 MaxUtteranceLength() const;
218 
219  // Used in InitSplitForLength(), this function outputs the set of allowed
220  // splits, represented as a sorted list of nonempty vectors (each split is a
221  // sorted list of chunk-sizes).
222  void InitSplits(std::vector<std::vector<int32> > *splits) const;
223 
224 
225  // Used in GetChunksForUtterance, this function selects the list of
226  // chunk-sizes for that utterance (later on, the positions and left/right
227  // context information for the chunks will be added to this). We don't call
228  // this a 'split', although it's also a list of chunk-sizes, because we
229  // randomize the order in which the chunk sizes appear, whereas for a 'split'
230  // we sort the chunk-sizes because a 'split' is conceptually an
231  // order-independent representation.
232  void GetChunkSizesForUtterance(int32 utterance_length,
233  std::vector<int32> *chunk_sizes) const;
234 
235 
236  // Used in GetChunksForUtterance, this function selects the 'gap sizes'
237  // before each of the chunks. These 'gap sizes' may be positive (representing
238  // a gap between chunks, or a number of frames at the beginning of the file that
239  // don't correspond to a chunk), or may be negative, corresponding to overlaps
240  // between adjacent chunks.
241  //
242  // If config_.frame_subsampling_factor > 1 and enforce_subsampling_factor is
243  // true, this function will ensure that all elements of 'gap_sizes' are
244  // multiples of config_.frame_subsampling_factor. (we always enforce this,
245  // but we set it to false inside a recursion when we recurse). Note: if
246  // config_.frame_subsampling_factor > 1, it's possible for the last chunk to
247  // go over 'utterance_length' by up to config_.frame_subsampling_factor - 1
248  // frames (i.e. it would require that many frames past the utterance end).
249  // This will be dealt with when generating egs, by duplicating the last frame.
250  void GetGapSizes(int32 utterance_length,
251  bool enforce_subsampling_factor,
252  const std::vector<int32> &chunk_sizes,
253  std::vector<int32> *gap_sizes) const;
254 
255  // this static function, used in GetGapSizes(), writes random values to a
256  // vector 'vec' such that the sum of those values equals n (n may be positive or
257  // negative). It tries to make those values as similar as possible (they will
258  // differ by at most one), and the location of the larger versus smaller
259  // values is random. 'vec' must be nonempty.
260  static void DistributeRandomlyUniform(int32 n,
261  std::vector<int32> *vec);
262 
263  // this static function, used in GetGapSizes(), writes values to a vector
264  // 'vec' such that the sum of those values equals n (n may be positive or
265  // negative). It tries to make those values, as exactly as it can,
266  // proportional to the values in 'magnitudes', which must be positive. 'vec'
267  // must be nonempty, and 'magnitudes' must be the same size as 'vec'.
268  static void DistributeRandomly(int32 n,
269  const std::vector<int32> &magnitudes,
270  std::vector<int32> *vec);
271 
272  // This function is responsible for setting the 'output_weights'
273  // members of the chunks.
274  void SetOutputWeights(int32 utterance_length,
275  std::vector<ChunkTimeInfo> *chunk_info) const;
276 
277  // Accumulate stats for diagnostics.
278  void AccStatsForUtterance(int32 utterance_length,
279  const std::vector<ChunkTimeInfo> &chunk_info);
280 
281 
283 
284  // The vector 'splits_for_length_' is indexed by the num-frames of a file, and
285  // gives us a list of alternative splits that we can use if the utterance has
286  // that many frames. For example, if splits_for_length_[100] = ( (25, 40, 40),
287  // (40, 65) ), it means we could either split as chunks of size (25, 40, 40)
288  // or as (40, 65) (we'll later randomize the order); the first split would use
289  // one chunk of size 25 and two chunks of size 40. In general these won't add up to
290  // exactly the length of the utterance; we'll have them overlap (or have small
291  // gaps between them) to account for this, and the details of this will be
292  // randomly decided per file. If splits_for_length_[u] is empty, it means the
293  // utterance was shorter than the smallest possible chunk size, so
294  // we will have to discard the utterance.
295 
296  // If an utterance's num-frames is >= splits_for_length_.size(), the way to find
297  // the split to use is to keep subtracting the primary num-frames (==
298  // config_.num_frames[0]) minus the num-frames-overlap, from the utterance
299  // length, until the resulting num-frames is < splits_for_length_.size(),
300  // and then add the subtracted number of copies of the primary
301  // num-frames to the split.
302  std::vector<std::vector<std::vector<int32> > > splits_for_length_;
303 
304  // Below are stats used for diagnostics.
305  int32 total_num_utterances_; // total input utterances.
306  int64 total_input_frames_; // total num-frames over all utterances (before
307  // splitting)
308  int64 total_frames_overlap_; // total number of frames that overlap between
309  // adjacent egs.
311  int64 total_frames_in_chunks_; // total of chunk-size times count of that
312  // chunk. equals the num-frames in all the
313  // output chunks, added up.
314  std::map<int32, int32> chunk_size_to_count_; // for each chunk size, gives
315  // the number of chunks with
316  // that size.
317 
318 };
319 
320 
322 public:
323  // The following configuration values are registered on the command line.
324  bool compress;
325  std::string measure_output_frames; // for back-compatibility, not used.
326  std::string minibatch_size;
327  std::string discard_partial_minibatches; // for back-compatibility, not used.
328 
// Constructor. Sets 'compress' to false and 'minibatch_size' to
// 'default_minibatch_size' (which defaults to "256"); the two
// back-compatibility options are set to the placeholder string "deprecated".
329  ExampleMergingConfig(const char *default_minibatch_size = "256"):
330  compress(false),
331  measure_output_frames("deprecated"),
332  minibatch_size(default_minibatch_size),
333  discard_partial_minibatches("deprecated") { }
334 
// Registers the command-line options --compress, --measure-output-frames,
// --discard-partial-minibatches and --minibatch-size with 'po', binding each
// one to the member variable of the corresponding name. The string literals
// are the user-visible help text; per that text, --measure-output-frames and
// --discard-partial-minibatches are accepted but ignored.
335  void Register(OptionsItf *po) {
336  po->Register("compress", &compress, "If true, compress the output examples "
337  "(not recommended unless you are writing to disk)");
338  po->Register("measure-output-frames", &measure_output_frames, "This "
339  "value will be ignored (included for back-compatibility)");
340  po->Register("discard-partial-minibatches", &discard_partial_minibatches,
341  "This value will be ignored (included for back-compatibility)");
342  po->Register("minibatch-size", &minibatch_size,
343  "String controlling the minibatch size. May be just an integer, "
344  "meaning a fixed minibatch size (e.g. --minibatch-size=128). "
345  "May be a list of ranges and values, e.g. --minibatch-size=32,64 "
346  "or --minibatch-size=16:32,64,128. All minibatches will be of "
347  "the largest size until the end of the input is reached; "
348  "then, increasingly smaller sizes will be allowed. Only egs "
349  "with the same structure (e.g num-frames) are merged. You may "
350  "specify different minibatch sizes for different sizes of eg "
351  "(defined as the maximum number of Indexes on any input), in "
352  "the format "
353  "--minibatch-size='eg_size1=mb_sizes1/eg_size2=mb_sizes2', e.g. "
354  "--minibatch-size=128=64:128,256/256=32:64,128. Egs are given "
355  "minibatch-sizes based on the specified eg-size closest to "
356  "their actual size.");
357  }
358 
359 
360  // this function computes the derived (private) parameters; it must be called
361  // after the command-line parameters are read and before MinibatchSize() is
362  // called.
363  void ComputeDerived();
364 
366 
381  int32 MinibatchSize(int32 size_of_eg,
382  int32 num_available_egs,
383  bool input_ended) const;
384 
385 
386  private:
387  // struct IntSet is a representation of something like 16:32,64, which is a
388  // nonempty list of either positive integers or ranges of positive integers.
389  // Conceptually it represents a set of positive integers.
390  struct IntSet {
391  // largest_size is the largest integer in any of the ranges (64 in this
392  // example).
394  // e.g. would contain ((16,32), (64,64)) in this example.
395  std::vector<std::pair<int32, int32> > ranges;
396  // Returns the largest value in any range (i.e. in the set of
397  // integers that this struct represents), that is <= max_value,
398  // or 0 if there is no value in any range that is <= max_value.
399  // In this example, this function would return the following:
400  // 128->64, 64->64, 63->32, 31->31, 16->16, 15->0, 0->0
401  int32 LargestValueInRange(int32 max_value) const;
402  };
403  static bool ParseIntSet(const std::string &str, IntSet *int_set);
404 
405  // 'rules' is derived from the configuration values above by ComputeDerived(),
406  // and are not set directly on the command line. 'rules' is a list of pairs
407  // (eg-size, int-set-of-minibatch-sizes); If no explicit eg-sizes were
408  // specified on the command line (i.e. there was no '=' sign in the
409  // --minibatch-size option), then we just set the int32 to 0.
410  std::vector<std::pair<int32, IntSet> > rules;
411 };
412 
413 
418 
419 
420 
421 
422 
428  public:
437  void WroteExample(int32 example_size, size_t structure_hash,
438  int32 minibatch_size);
439 
443  void DiscardedExamples(int32 example_size, size_t structure_hash,
444  int32 num_discarded);
445 
448  void PrintStats() const;
449 
450  private:
451  // this struct stores the stats for examples of a particular size and
452  // structure.
455  // maps from minibatch-size (i.e. number of egs that were
456  // aggregated into that minibatch), to the number of such
457  // minibatches written.
458  unordered_map<int32, int32> minibatch_to_num_written;
459  StatsForExampleSize(): num_discarded(0) { }
460  };
461 
462 
463  typedef unordered_map<std::pair<int32, size_t>, StatsForExampleSize,
465 
466  // this maps from a pair (example_size, structure_hash) to the stats for
467  // examples with those characteristics.
468  StatsType stats_;
469 
470  void PrintAggregateStats() const;
471  void PrintSpecificStats() const;
472 
473 };
474 
475 
481  public:
482  ExampleMerger(const ExampleMergingConfig &config,
483  NnetExampleWriter *writer);
484 
485  // This function accepts an example, and if possible, writes a merged example
486  // out. The ownership of the pointer 'a' is transferred to this class when
487  // you call this function.
488  void AcceptExample(NnetExample *a);
489 
490  // This function announces to the class that the input has finished, so it
491  // should flush out any smaller-sized minibatches, as dictated by the config.
492  // This will be called in the destructor, but you can call it explicitly when
493  // all the input is done if you want to; it won't repeat anything if called
494  // twice. It also prints the stats.
495  void Finish();
496 
497  // Returns a suitable exit status for a program: calls Finish() first, and
// then returns 0 if num_egs_written_ is greater than zero, and 1 otherwise.
498  int32 ExitStatus() { Finish(); return (num_egs_written_ > 0 ? 0 : 1); }
499 
// Destructor; calls Finish().
500  ~ExampleMerger() { Finish(); };
501  private:
502  // called by Finish() and AcceptExample(). Merges, updates the
503  // stats, and writes.
504  void WriteMinibatch(const std::vector<NnetExample> &egs);
505 
506  bool finished_;
511 
512  // Note: the "key" into the egs is the first element of the vector.
513  typedef unordered_map<NnetExample*, std::vector<NnetExample*>,
516  MapType eg_to_egs_;
517 };
518 
519 } // namespace nnet3
520 } // namespace kaldi
521 
522 #endif // KALDI_NNET3_NNET_EXAMPLE_UTILS_H_
NnetExample is the input data and corresponding label (or labels) for one or more frames of input...
Definition: nnet-example.h:111
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
This class is responsible for storing, and displaying in log messages, statistics about how examples ...
unordered_map< std::pair< int32, size_t >, StatsForExampleSize, PairHasher< int32, size_t > > StatsType
std::vector< std::pair< int32, IntSet > > rules
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
std::vector< std::vector< std::vector< int32 > > > splits_for_length_
void ShiftExampleTimes(int32 t_offset, const std::vector< std::string > &exclude_names, NnetExample *eg)
Shifts the time-index t of everything in the "eg" by adding "t_offset" to all "t" values...
const ExampleMergingConfig & config_
This class is responsible for arranging examples in groups that have the same structure (i...
virtual void Register(const std::string &name, bool *ptr, const std::string &doc)=0
The two main classes defined in this header are struct ComputationRequest, which basically defines a ...
This hashing object hashes just the structural aspects of the NnetExample without looking at the valu...
Definition: nnet-example.h:145
void WriteVectorAsChar(std::ostream &os, bool binary, const VectorBase< BaseFloat > &vec)
struct rnnlm::@11::@12 n
int32 GetNnetExampleSize(const NnetExample &a)
This function returns the 'size' of a nnet-example as defined for purposes of merging egs...
ExampleMergingConfig(const char *default_minibatch_size="256")
const ExampleGenerationConfig & config_
unordered_map< NnetExample *, std::vector< NnetExample * >, NnetExampleStructureHasher, NnetExampleStructureCompare > MapType
This comparator object compares just the structural aspects of the NnetExample without looking at the...
Definition: nnet-example.h:159
const ExampleGenerationConfig & Config() const
std::map< int32, int32 > chunk_size_to_count_
void ReadVectorAsChar(std::istream &is, bool binary, Vector< BaseFloat > *vec)
void AccStatsForUtterance(const TransitionModel &trans_model, const AmDiagGmm &am_gmm, const GaussPost &gpost, const Matrix< BaseFloat > &feats, FmllrRawAccs *accs)
struct ChunkTimeInfo is used by class UtteranceSplitter to output information about how we split an u...
std::vector< std::pair< int32, int32 > > ranges
std::vector< BaseFloat > output_weights
void ComputeDerived()
This function decodes 'num_frames_str' into 'num_frames', and ensures that the members of 'num_frames...
A hashing function-object for pairs of ints.
Definition: stl-utils.h:235
void GetComputationRequest(const Nnet &nnet, const NnetExample &eg, bool need_model_derivative, bool store_component_stats, ComputationRequest *request)
This function takes a NnetExample (which should already have been frame-selected, if desired...
void MergeExamples(const std::vector< NnetExample > &src, bool compress, NnetExample *merged_eg)
Merge a set of input examples into a single example (typically the size of "src" will be the minibatc...