nnet-get-weighted-egs.cc
Go to the documentation of this file.
1 // nnet2bin/nnet-get-weighted-egs.cc
2 
3 // Copyright 2013-2014 (Author: Vimal Manohar)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
#include <cmath>
#include <limits>

#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "hmm/transition-model.h"
#include "nnet2/nnet-example-functions.h"
24 
25 namespace kaldi {
26 namespace nnet2 {
27 
28 // returns an integer randomly drawn with expected value "expected_count"
29 // (will be either floor(expected_count) or ceil(expected_count)).
30 // this will go into an infinite loop if expected_count is very huge, but
31 // it should never be that huge.
32 // In the normal case, "expected_count" will be between zero and one.
33 int32 GetCount(double expected_count) {
34  KALDI_ASSERT(expected_count >= 0.0);
35  int32 ans = 0;
36  while (expected_count > 1.0) {
37  ans++;
38  expected_count--;
39  }
40  if (WithProb(expected_count))
41  ans++;
42  return ans;
43 }
44 
45 static void ProcessFile(const MatrixBase<BaseFloat> &feats,
46  const Posterior &pdf_post,
47  const std::string &utt_id,
48  const Vector<BaseFloat> &weights,
49  int32 left_context,
50  int32 right_context,
51  int32 const_feat_dim,
52  BaseFloat keep_proportion,
53  BaseFloat weight_threshold,
54  bool use_frame_selection,
55  bool use_frame_weights,
56  int64 *num_frames_written,
57  int64 *num_frames_skipped,
58  NnetExampleWriter *example_writer) {
59  KALDI_ASSERT(feats.NumRows() == static_cast<int32>(pdf_post.size()));
60  int32 feat_dim = feats.NumCols();
61  KALDI_ASSERT(const_feat_dim < feat_dim);
62  int32 basic_feat_dim = feat_dim - const_feat_dim;
63  NnetExample eg;
64  Matrix<BaseFloat> input_frames(left_context + 1 + right_context,
65  basic_feat_dim);
66  eg.left_context = left_context;
67  // TODO: modify this code, and this binary itself, to support the --num-frames
68  // option to allow multiple frames per eg.
69  for (int32 i = 0; i < feats.NumRows(); i++) {
70  int32 count = GetCount(keep_proportion); // number of times
71  // we'll write this out (1 by default).
72  if (count > 0) {
73  // Set up "input_frames".
74  for (int32 j = -left_context; j <= right_context; j++) {
75  int32 j2 = j + i;
76  if (j2 < 0) j2 = 0;
77  if (j2 >= feats.NumRows()) j2 = feats.NumRows() - 1;
78  SubVector<BaseFloat> src(feats, j2), dest(input_frames,
79  j + left_context);
80  dest.CopyFromVec(src);
81  }
82  eg.labels.push_back(pdf_post[i]);
83  eg.input_frames = input_frames;
84  if (const_feat_dim > 0) {
85  // we'll normally reach here if we're using online-estimated iVectors.
86  SubVector<BaseFloat> const_part(feats.Row(i),
87  basic_feat_dim, const_feat_dim);
88  eg.spk_info.CopyFromVec(const_part);
89  }
90  if (use_frame_selection) {
91  if (weights(i) < weight_threshold) {
92  (*num_frames_skipped)++;
93  continue;
94  }
95  }
96  std::ostringstream os;
97  os << utt_id << "-" << i;
98  std::string key = os.str(); // key in the archive is the number of the example
99 
100  for (int32 c = 0; c < count; c++)
101  example_writer->Write(key, eg);
102  }
103  }
104 }
105 
106 
107 } // namespace nnet2
108 } // namespace kaldi
109 
110 int main(int argc, char *argv[]) {
111  try {
112  using namespace kaldi;
113  using namespace kaldi::nnet2;
114  typedef kaldi::int32 int32;
115  typedef kaldi::int64 int64;
116 
117  const char *usage =
118  "Get frame-by-frame examples of data for neural network training.\n"
119  "Essentially this is a format change from features and posteriors\n"
120  "into a special frame-by-frame format. To split randomly into\n"
121  "different subsets, do nnet-copy-egs with --random=true, but\n"
122  "note that this does not randomize the order of frames.\n"
123  "\n"
124  "Usage: nnet-get-weighted-egs [options] <features-rspecifier> "
125  "<pdf-post-rspecifier> <weights-rspecifier> <training-examples-out>\n"
126  "\n"
127  "An example [where $feats expands to the actual features]:\n"
128  "nnet-get-weighted-egs --left-context=8 --right-context=8 \"$feats\" \\\n"
129  " \"ark:gunzip -c exp/nnet/ali.1.gz | ali-to-pdf exp/nnet/1.nnet ark:- ark:- | ali-to-post ark:- ark:- |\" \\\n"
130  " ark:- \n"
131  "Note: the --left-context and --right-context would be derived from\n"
132  "the output of nnet-info.";
133 
134 
135  int32 left_context = 0, right_context = 0, const_feat_dim = 0;
136  int32 srand_seed = 0;
137  BaseFloat keep_proportion = 1.0;
138  BaseFloat weight_threshold = 0.0;
139  bool use_frame_selection = true, use_frame_weights=false;
140 
141  ParseOptions po(usage);
142  po.Register("left-context", &left_context, "Number of frames of left context "
143  "the neural net requires.");
144  po.Register("right-context", &right_context, "Number of frames of right context "
145  "the neural net requires.");
146  po.Register("const-feat-dim", &const_feat_dim, "If specified, the last "
147  "const-feat-dim dimensions of the feature input are treated as "
148  "constant over the context window (so are not spliced)");
149  po.Register("keep-proportion", &keep_proportion, "If <1.0, this program will "
150  "randomly keep this proportion of the input samples. If >1.0, it will "
151  "in expectation copy a sample this many times. It will copy it a number "
152  "of times equal to floor(keep-proportion) or ceil(keep-proportion).");
153  po.Register("srand", &srand_seed, "Seed for random number generator "
154  "(only relevant if --keep-proportion != 1.0)");
155  po.Register("weight-threshold", &weight_threshold, "Keep only frames with weights "
156  "above this threshold.");
157  po.Register("use-frame-selection", &use_frame_selection, "Remove the frames below threshold.");
158  po.Register("use-frame-weights", &use_frame_weights, "Scale the error derivatives by the weight");
159 
160  po.Read(argc, argv);
161 
162  srand(srand_seed);
163 
164  if (po.NumArgs() != 4) {
165  po.PrintUsage();
166  exit(1);
167  }
168 
169  std::string feature_rspecifier = po.GetArg(1),
170  pdf_post_rspecifier = po.GetArg(2),
171  weights_rspecifier = po.GetArg(3),
172  examples_wspecifier = po.GetArg(4);
173 
174  // Read in all the training files.
175  SequentialBaseFloatMatrixReader feat_reader(feature_rspecifier);
176  RandomAccessPosteriorReader pdf_post_reader(pdf_post_rspecifier);
177  RandomAccessBaseFloatVectorReader weights_reader(weights_rspecifier);
178  NnetExampleWriter example_writer(examples_wspecifier);
179 
180  int32 num_done = 0, num_err = 0;
181  int64 num_frames_written = 0;
182  int64 num_frames_skipped = 0;
183 
184  for (; !feat_reader.Done(); feat_reader.Next()) {
185  std::string key = feat_reader.Key();
186  const Matrix<BaseFloat> &feats = feat_reader.Value();
187  if (!pdf_post_reader.HasKey(key)) {
188  KALDI_WARN << "No pdf-level posterior for key " << key;
189  num_err++;
190  } else {
191  const Posterior &pdf_post = pdf_post_reader.Value(key);
192  if (pdf_post.size() != feats.NumRows()) {
193  KALDI_WARN << "Posterior has wrong size " << pdf_post.size()
194  << " versus " << feats.NumRows();
195  num_err++;
196  continue;
197  }
198  if (!weights_reader.HasKey(key)) {
199  KALDI_ERR << "No weights for utterance " << key;
200  //ProcessFile(feats, pdf_post, NULL,
201  // left_context, right_context, const_feat_dim, keep_proportion,
202  // weight_threshold, false, false, &num_frames_written,
203  // &num_frames_skipped, &example_writer);
204  } else {
205  Vector<BaseFloat> weights = weights_reader.Value(key);
206  if (weights.Dim() != static_cast<int32>(pdf_post.size())) {
207  KALDI_WARN << "Weights for utterance " << key
208  << " have wrong size, " << weights.Dim()
209  << " vs. " << pdf_post.size();
210  num_err++;
211  continue;
212  }
213  ProcessFile(feats, pdf_post, key, weights, left_context, right_context,
214  const_feat_dim, keep_proportion, weight_threshold,
215  use_frame_selection, use_frame_weights,
216  &num_frames_written, &num_frames_skipped, &example_writer);
217  }
218  num_done++;
219  }
220  }
221 
222  KALDI_LOG << "Finished generating examples, "
223  << "successfully processed " << num_done
224  << " feature files, wrote " << num_frames_written << " examples, "
225  << "skipped " << num_frames_skipped << " examples, "
226  << num_err << " files had errors.";
227  return (num_done == 0 ? 1 : 0);
228  } catch(const std::exception &e) {
229  std::cerr << e.what() << '\n';
230  return -1;
231  }
232 }
CompressedMatrix input_frames
The input data, with NumRows() >= labels.size() + left_context; it includes features to the left and ...
Definition: nnet-example.h:49
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
NnetExample is the input data and corresponding label (or labels) for one or more frames of input...
Definition: nnet-example.h:36
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
Base class which provides matrix operations not involving resizing or allocation. ...
Definition: kaldi-matrix.h:49
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
int32 left_context
The number of frames of left context (we can work out the #frames of right context from input_frames...
Definition: nnet-example.h:53
bool WithProb(BaseFloat prob, struct RandomState *state)
Definition: kaldi-math.cc:72
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
int main(int argc, char *argv[])
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
void CopyFromVec(const VectorBase< Real > &v)
Copy data from another vector (must match own size).
const size_t count
float BaseFloat
Definition: kaldi-types.h:29
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
Definition: posterior.h:42
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
const T & Value(const std::string &key)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
bool HasKey(const std::string &key)
int NumArgs() const
Number of positional parameters (c.f. argc-1).
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
std::vector< std::vector< std::pair< int32, BaseFloat > > > labels
The label(s) for each frame in a sequence of frames; in the normal case, this will be just [ [ (pdf-i...
Definition: nnet-example.h:43
static void ProcessFile(const MatrixBase< BaseFloat > &feats, const Posterior &pdf_post, const std::string &utt_id, int32 left_context, int32 right_context, int32 num_frames, int32 const_feat_dim, int64 *num_frames_written, int64 *num_egs_written, NnetExampleWriter *example_writer)
Definition: nnet-get-egs.cc:32
#define KALDI_LOG
Definition: kaldi-error.h:153
int32 GetCount(double expected_count)
Note on how to parse this filename: it contains functions related to neural-net training examples...
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:501
Vector< BaseFloat > spk_info
The speaker-specific input, if any, or an empty vector if we're not using these features.
Definition: nnet-example.h:58