All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
nnet3-get-egs.cc
Go to the documentation of this file.
1 // nnet3bin/nnet3-get-egs.cc
2 
3 // Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey)
4 // 2014 Vimal Manohar
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include <sstream>
22 #include "base/kaldi-common.h"
23 #include "util/common-utils.h"
24 #include "hmm/transition-model.h"
25 #include "hmm/posterior.h"
26 #include "nnet3/nnet-example.h"
28 
29 namespace kaldi {
30 namespace nnet3 {
31 
32 
33 static bool ProcessFile(const GeneralMatrix &feats,
34  const MatrixBase<BaseFloat> *ivector_feats,
35  int32 ivector_period,
36  const Posterior &pdf_post,
37  const std::string &utt_id,
38  bool compress,
39  int32 num_pdfs,
40  int32 length_tolerance,
41  UtteranceSplitter *utt_splitter,
42  NnetExampleWriter *example_writer) {
43  int32 num_input_frames = feats.NumRows();
44  if (!utt_splitter->LengthsMatch(utt_id, num_input_frames,
45  static_cast<int32>(pdf_post.size()),
46  length_tolerance))
47  return false; // LengthsMatch() will have printed a warning.
48 
49  std::vector<ChunkTimeInfo> chunks;
50 
51  utt_splitter->GetChunksForUtterance(num_input_frames, &chunks);
52 
53  if (chunks.empty()) {
54  KALDI_WARN << "Not producing egs for utterance " << utt_id
55  << " because it is too short: "
56  << num_input_frames << " frames.";
57  }
58 
59  // 'frame_subsampling_factor' is not used in any recipes at the time of
60  // writing, this is being supported to unify the code with the 'chain' recipes
61  // and in case we need it for some reason in future.
62  int32 frame_subsampling_factor =
63  utt_splitter->Config().frame_subsampling_factor;
64 
65  for (size_t c = 0; c < chunks.size(); c++) {
66  const ChunkTimeInfo &chunk = chunks[c];
67 
68  int32 tot_input_frames = chunk.left_context + chunk.num_frames +
69  chunk.right_context;
70 
71  int32 start_frame = chunk.first_frame - chunk.left_context;
72 
73  GeneralMatrix input_frames;
74  ExtractRowRangeWithPadding(feats, start_frame, tot_input_frames,
75  &input_frames);
76 
77  // 'input_frames' now stores the relevant rows (maybe with padding) from the
78  // original Matrix or (more likely) CompressedMatrix. If a CompressedMatrix,
79  // it does this without un-compressing and re-compressing, so there is no loss
80  // of accuracy.
81 
82  NnetExample eg;
83  // call the regular input "input".
84  eg.io.push_back(NnetIo("input", -chunk.left_context, input_frames));
85 
86  if (ivector_feats != NULL) {
87  // if applicable, add the iVector feature.
88  // choose iVector from a random frame in the chunk
89  int32 ivector_frame = RandInt(start_frame,
90  start_frame + num_input_frames - 1),
91  ivector_frame_subsampled = ivector_frame / ivector_period;
92  if (ivector_frame_subsampled < 0)
93  ivector_frame_subsampled = 0;
94  if (ivector_frame_subsampled >= ivector_feats->NumRows())
95  ivector_frame_subsampled = ivector_feats->NumRows() - 1;
96  Matrix<BaseFloat> ivector(1, ivector_feats->NumCols());
97  ivector.Row(0).CopyFromVec(ivector_feats->Row(ivector_frame_subsampled));
98  eg.io.push_back(NnetIo("ivector", 0, ivector));
99  }
100 
101  // Note: chunk.first_frame and chunk.num_frames will both be
102  // multiples of frame_subsampling_factor.
103  int32 start_frame_subsampled = chunk.first_frame / frame_subsampling_factor,
104  num_frames_subsampled = chunk.num_frames / frame_subsampling_factor;
105 
106  Posterior labels(num_frames_subsampled);
107 
108  // TODO: it may be that using these weights is not actually helpful (with
109  // chain training, it was not), and that setting them all to 1 is better.
110  // We could add a boolean option to this program to control that; but I
111  // don't want to add such an option if experiments show that it is not
112  // helpful.
113  for (int32 i = 0; i < num_frames_subsampled; i++) {
114  int32 t = i + start_frame_subsampled;
115  if (t < pdf_post.size())
116  labels[i] = pdf_post[t];
117  for (std::vector<std::pair<int32, BaseFloat> >::iterator
118  iter = labels[i].begin(); iter != labels[i].end(); ++iter)
119  iter->second *= chunk.output_weights[i];
120  }
121 
122  eg.io.push_back(NnetIo("output", num_pdfs, 0, labels, frame_subsampling_factor));
123 
124  if (compress)
125  eg.Compress();
126 
127  std::ostringstream os;
128  os << utt_id << "-" << chunk.first_frame;
129 
130  std::string key = os.str(); // key is <utt_id>-<frame_id>
131 
132  example_writer->Write(key, eg);
133  }
134  return true;
135 }
136 
137 } // namespace nnet3
138 } // namespace kaldi
139 
140 int main(int argc, char *argv[]) {
141  try {
142  using namespace kaldi;
143  using namespace kaldi::nnet3;
144  typedef kaldi::int32 int32;
145  typedef kaldi::int64 int64;
146 
147  const char *usage =
148  "Get frame-by-frame examples of data for nnet3 neural network training.\n"
149  "Essentially this is a format change from features and posteriors\n"
150  "into a special frame-by-frame format. This program handles the\n"
151  "common case where you have some input features, possibly some\n"
152  "iVectors, and one set of labels. If people in future want to\n"
153  "do different things they may have to extend this program or create\n"
154  "different versions of it for different tasks (the egs format is quite\n"
155  "general)\n"
156  "\n"
157  "Usage: nnet3-get-egs [options] <features-rspecifier> "
158  "<pdf-post-rspecifier> <egs-out>\n"
159  "\n"
160  "An example [where $feats expands to the actual features]:\n"
161  "nnet3-get-egs --num-pdfs=2658 --left-context=12 --right-context=9 --num-frames=8 \"$feats\"\\\n"
162  "\"ark:gunzip -c exp/nnet/ali.1.gz | ali-to-pdf exp/nnet/1.nnet ark:- ark:- | ali-to-post ark:- ark:- |\" \\\n"
163  " ark:- \n"
164  "See also: nnet3-chain-get-egs, nnet3-get-egs-simple\n";
165 
166 
167  bool compress = true;
168  int32 num_pdfs = -1, length_tolerance = 100,
169  targets_length_tolerance = 2,
170  online_ivector_period = 1;
171 
172  ExampleGenerationConfig eg_config; // controls num-frames,
173  // left/right-context, etc.
174 
175  std::string online_ivector_rspecifier;
176 
177  ParseOptions po(usage);
178 
179  po.Register("compress", &compress, "If true, write egs with input features "
180  "in compressed format (recommended). This is "
181  "only relevant if the features being read are un-compressed; "
182  "if already compressed, we keep we same compressed format when "
183  "dumping egs.");
184  po.Register("num-pdfs", &num_pdfs, "Number of pdfs in the acoustic "
185  "model");
186  po.Register("ivectors", &online_ivector_rspecifier, "Alias for "
187  "--online-ivectors option, for back compatibility");
188  po.Register("online-ivectors", &online_ivector_rspecifier, "Rspecifier of "
189  "ivector features, as a matrix.");
190  po.Register("online-ivector-period", &online_ivector_period, "Number of "
191  "frames between iVectors in matrices supplied to the "
192  "--online-ivectors option");
193  po.Register("length-tolerance", &length_tolerance, "Tolerance for "
194  "difference in num-frames between feat and ivector matrices");
195  po.Register("targets-length-tolerance", &targets_length_tolerance,
196  "Tolerance for "
197  "difference in num-frames (after subsampling) between "
198  "feature matrix and posterior");
199  eg_config.Register(&po);
200 
201  po.Read(argc, argv);
202 
203  if (po.NumArgs() != 3) {
204  po.PrintUsage();
205  exit(1);
206  }
207 
208  if (num_pdfs <= 0)
209  KALDI_ERR << "--num-pdfs options is required.";
210 
211  eg_config.ComputeDerived();
212  UtteranceSplitter utt_splitter(eg_config);
213 
214  std::string feature_rspecifier = po.GetArg(1),
215  pdf_post_rspecifier = po.GetArg(2),
216  examples_wspecifier = po.GetArg(3);
217 
218  // SequentialGeneralMatrixReader can read either a Matrix or
219  // CompressedMatrix (or SparseMatrix, but not as relevant here),
220  // and it retains the type. This way, we can generate parts of
221  // the feature matrices without uncompressing and re-compressing.
222  SequentialGeneralMatrixReader feat_reader(feature_rspecifier);
223  RandomAccessPosteriorReader pdf_post_reader(pdf_post_rspecifier);
224  NnetExampleWriter example_writer(examples_wspecifier);
225  RandomAccessBaseFloatMatrixReader online_ivector_reader(
226  online_ivector_rspecifier);
227 
228  int32 num_err = 0;
229 
230  for (; !feat_reader.Done(); feat_reader.Next()) {
231  std::string key = feat_reader.Key();
232  const GeneralMatrix &feats = feat_reader.Value();
233  if (!pdf_post_reader.HasKey(key)) {
234  KALDI_WARN << "No pdf-level posterior for key " << key;
235  num_err++;
236  } else {
237  const Posterior &pdf_post = pdf_post_reader.Value(key);
238  const Matrix<BaseFloat> *online_ivector_feats = NULL;
239  if (!online_ivector_rspecifier.empty()) {
240  if (!online_ivector_reader.HasKey(key)) {
241  KALDI_WARN << "No iVectors for utterance " << key;
242  num_err++;
243  continue;
244  } else {
245  // this address will be valid until we call HasKey() or Value()
246  // again.
247  online_ivector_feats = &(online_ivector_reader.Value(key));
248  }
249  }
250 
251  if (online_ivector_feats != NULL &&
252  (abs(feats.NumRows() - (online_ivector_feats->NumRows() *
253  online_ivector_period)) > length_tolerance
254  || online_ivector_feats->NumRows() == 0)) {
255  KALDI_WARN << "Length difference between feats " << feats.NumRows()
256  << " and iVectors " << online_ivector_feats->NumRows()
257  << "exceeds tolerance " << length_tolerance;
258  num_err++;
259  continue;
260  }
261 
262  if (!ProcessFile(feats, online_ivector_feats, online_ivector_period,
263  pdf_post, key, compress, num_pdfs,
264  targets_length_tolerance,
265  &utt_splitter, &example_writer))
266  num_err++;
267  }
268  }
269  if (num_err > 0)
270  KALDI_WARN << num_err << " utterances had errors and could "
271  "not be processed.";
272  // utt_splitter prints stats in its destructor.
273  return utt_splitter.ExitStatus();
274  } catch(const std::exception &e) {
275  std::cerr << e.what() << '\n';
276  return -1;
277  }
278 }
NnetExample is the input data and corresponding label (or labels) for one or more frames of input...
Definition: nnet-example.h:111
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
This class is a wrapper that enables you to store a matrix in one of three forms: either as a Matrix<...
bool LengthsMatch(const std::string &utt, int32 utterance_length, int32 supervision_length, int32 length_tolerance=0) const
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
Base class which provides matrix operations not involving resizing or allocation. ...
Definition: kaldi-matrix.h:49
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
Definition: posterior.h:42
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
const T & Value(const std::string &key)
static bool ProcessFile(const discriminative::SplitDiscriminativeSupervisionOptions &config, const TransitionModel &tmodel, const MatrixBase< BaseFloat > &feats, const MatrixBase< BaseFloat > *ivector_feats, int32 ivector_period, const discriminative::DiscriminativeSupervision &supervision, const std::string &utt_id, bool compress, UtteranceSplitter *utt_splitter, NnetDiscriminativeExampleWriter *example_writer)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
void ExtractRowRangeWithPadding(const GeneralMatrix &in, int32 row_offset, int32 num_rows, GeneralMatrix *out)
This function extracts a row-range of a GeneralMatrix and writes as a GeneralMatrix containing the sa...
#define KALDI_ERR
Definition: kaldi-error.h:147
void Compress()
Compresses any (input) features that are not sparse.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
#define KALDI_WARN
Definition: kaldi-error.h:150
bool HasKey(const std::string &key)
const ExampleGenerationConfig & Config() const
int NumArgs() const
Number of positional parameters (c.f. argc-1).
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
MatrixIndexT NumRows() const
void GetChunksForUtterance(int32 utterance_length, std::vector< ChunkTimeInfo > *chunk_info)
struct ChunkTimeInfo is used by class UtteranceSplitter to output information about how we split an u...
std::vector< NnetIo > io
"io" contains the input and output.
Definition: nnet-example.h:116
int main(int argc, char *argv[])
std::vector< BaseFloat > output_weights
int32 RandInt(int32 min_val, int32 max_val, struct RandomState *state)
Definition: kaldi-math.cc:95
void ComputeDerived()
This function decodes 'num_frames_str' into 'num_frames', and ensures that the members of 'num_frames...