nnet3-get-egs-dense-targets.cc
Go to the documentation of this file.
1 // nnet3bin/nnet3-get-egs-dense-targets.cc
2 
3 // Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey)
4 // 2014-2015 Vimal Manohar
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include <sstream>
22 
23 #include "base/kaldi-common.h"
24 #include "util/common-utils.h"
25 #include "hmm/transition-model.h"
26 #include "hmm/posterior.h"
27 #include "nnet3/nnet-example.h"
29 
30 namespace kaldi {
31 namespace nnet3 {
32 
33 
34 static bool ProcessFile(const GeneralMatrix &feats,
35  const MatrixBase<BaseFloat> *ivector_feats,
36  int32 ivector_period,
37  const MatrixBase<BaseFloat> &targets,
38  const std::string &utt_id,
39  bool compress,
40  int32 num_targets,
41  int32 length_tolerance,
42  UtteranceSplitter *utt_splitter,
43  NnetExampleWriter *example_writer) {
44  int32 num_input_frames = feats.NumRows();
45  if (!utt_splitter->LengthsMatch(utt_id, num_input_frames,
46  targets.NumRows(),
47  length_tolerance)) {
48  return false;
49  }
50  if (targets.NumRows() == 0)
51  return false;
52  KALDI_ASSERT(num_targets < 0 || targets.NumCols() == num_targets);
53 
54  std::vector<ChunkTimeInfo> chunks;
55 
56  utt_splitter->GetChunksForUtterance(num_input_frames, &chunks);
57 
58  if (chunks.empty()) {
59  KALDI_WARN << "Not producing egs for utterance " << utt_id
60  << " because it is too short: "
61  << num_input_frames << " frames.";
62  return false;
63  }
64 
65  // 'frame_subsampling_factor' is not used in any recipes at the time of
66  // writing, this is being supported to unify the code with the 'chain' recipes
67  // and in case we need it for some reason in future.
68  int32 frame_subsampling_factor =
69  utt_splitter->Config().frame_subsampling_factor;
70 
71  for (size_t c = 0; c < chunks.size(); c++) {
72  const ChunkTimeInfo &chunk = chunks[c];
73 
74  int32 tot_input_frames = chunk.left_context + chunk.num_frames +
75  chunk.right_context;
76 
77  int32 start_frame = chunk.first_frame - chunk.left_context;
78 
79  GeneralMatrix input_frames;
80  ExtractRowRangeWithPadding(feats, start_frame, tot_input_frames,
81  &input_frames);
82 
83  // 'input_frames' now stores the relevant rows (maybe with padding) from the
84  // original Matrix or (more likely) CompressedMatrix. If a CompressedMatrix,
85  // it does this without un-compressing and re-compressing, so there is no loss
86  // of accuracy.
87 
88  NnetExample eg;
89  // call the regular input "input".
90  eg.io.push_back(NnetIo("input", -chunk.left_context, input_frames));
91 
92  if (ivector_feats != NULL) {
93  // if applicable, add the iVector feature.
94  // choose iVector from a random frame in the chunk
95  int32 ivector_frame = RandInt(start_frame,
96  start_frame + num_input_frames - 1),
97  ivector_frame_subsampled = ivector_frame / ivector_period;
98  if (ivector_frame_subsampled < 0)
99  ivector_frame_subsampled = 0;
100  if (ivector_frame_subsampled >= ivector_feats->NumRows())
101  ivector_frame_subsampled = ivector_feats->NumRows() - 1;
102  Matrix<BaseFloat> ivector(1, ivector_feats->NumCols());
103  ivector.Row(0).CopyFromVec(ivector_feats->Row(ivector_frame_subsampled));
104  eg.io.push_back(NnetIo("ivector", 0, ivector));
105  }
106 
107  // Note: chunk.first_frame and chunk.num_frames will both be
108  // multiples of frame_subsampling_factor.
109  int32 start_frame_subsampled = chunk.first_frame / frame_subsampling_factor,
110  num_frames_subsampled = chunk.num_frames / frame_subsampling_factor;
111 
112  KALDI_ASSERT(start_frame_subsampled + num_frames_subsampled - 1 <
113  targets.NumRows());
114 
115 
116  // add the labels.
117  Matrix<BaseFloat> targets_part(num_frames_subsampled, targets.NumCols());
118  for (int32 i = 0; i < num_frames_subsampled; i++) {
119  // Copy the i^th row of the target matrix from the (t+i)^th row of the
120  // input targets matrix
121  int32 t = i + start_frame_subsampled;
122  if (t >= targets.NumRows())
123  t = targets.NumRows() - 1;
124  SubVector<BaseFloat> this_target_dest(targets_part, i);
125  SubVector<BaseFloat> this_target_src(targets, t);
126  this_target_dest.CopyFromVec(this_target_src);
127  }
128 
129  // push this created targets matrix into the eg
130  eg.io.push_back(NnetIo("output", 0, targets_part, frame_subsampling_factor));
131 
132  if (compress)
133  eg.Compress();
134 
135  std::ostringstream os;
136  os << utt_id << "-" << chunk.first_frame;
137 
138  std::string key = os.str(); // key is <utt_id>-<frame_id>
139 
140  example_writer->Write(key, eg);
141  }
142  return true;
143 }
144 
145 } // namespace nnet3
146 } // namespace kaldi
147 
148 int main(int argc, char *argv[]) {
149  try {
150  using namespace kaldi;
151  using namespace kaldi::nnet3;
152  typedef kaldi::int32 int32;
153  typedef kaldi::int64 int64;
154 
155  const char *usage =
156  "Get frame-by-frame examples of data for nnet3 neural network training.\n"
157  "This program is similar to nnet3-get-egs, but the targets here are "
158  "dense matrices instead of posteriors (sparse matrices).\n"
159  "This is useful when you want the targets to be continuous real-valued "
160  "with the neural network possibly trained with a quadratic objective\n"
161  "\n"
162  "Usage: nnet3-get-egs-dense-targets --num-targets=<n> [options] "
163  "<features-rspecifier> <targets-rspecifier> <egs-out>\n"
164  "\n"
165  "An example [where $feats expands to the actual features]:\n"
166  "nnet-get-egs-dense-targets --num-targets=26 --left-context=12 \\\n"
167  "--right-context=9 --num-frames=8 \"$feats\" \\\n"
168  "\"ark:copy-matrix ark:exp/snrs/snr.1.ark ark:- |\"\n"
169  " ark:- \n";
170 
171 
172  bool compress = true;
173  int32 num_targets = -1, length_tolerance = 100,
174  targets_length_tolerance = 2,
175  online_ivector_period = 1;
176 
177  ExampleGenerationConfig eg_config; // controls num-frames,
178  // left/right-context, etc.
179 
180  std::string online_ivector_rspecifier;
181 
182  ParseOptions po(usage);
183 
184  po.Register("compress", &compress, "If true, write egs with input features "
185  "in compressed format (recommended). This is "
186  "only relevant if the features being read are un-compressed; "
187  "if already compressed, we keep we same compressed format when "
188  "dumping egs.");
189  po.Register("num-targets", &num_targets, "Output dimension in egs, "
190  "only used to check targets have correct dim if supplied.");
191  po.Register("ivectors", &online_ivector_rspecifier, "Alias for "
192  "--online-ivectors option, for back compatibility");
193  po.Register("online-ivectors", &online_ivector_rspecifier, "Rspecifier of "
194  "ivector features, as a matrix.");
195  po.Register("online-ivector-period", &online_ivector_period, "Number of "
196  "frames between iVectors in matrices supplied to the "
197  "--online-ivectors option");
198  po.Register("length-tolerance", &length_tolerance, "Tolerance for "
199  "difference in num-frames between feat and ivector matrices");
200  po.Register("targets-length-tolerance", &targets_length_tolerance,
201  "Tolerance for "
202  "difference in num-frames (after subsampling) between "
203  "feature and target matrices");
204  eg_config.Register(&po);
205 
206  po.Read(argc, argv);
207 
208  if (po.NumArgs() != 3) {
209  po.PrintUsage();
210  exit(1);
211  }
212 
213  eg_config.ComputeDerived();
214  UtteranceSplitter utt_splitter(eg_config);
215 
216  std::string feature_rspecifier = po.GetArg(1),
217  matrix_rspecifier = po.GetArg(2),
218  examples_wspecifier = po.GetArg(3);
219 
220  // SequentialGeneralMatrixReader can read either a Matrix or
221  // CompressedMatrix (or SparseMatrix, but not as relevant here),
222  // and it retains the type. This way, we can generate parts of
223  // the feature matrices without uncompressing and re-compressing.
224  SequentialGeneralMatrixReader feat_reader(feature_rspecifier);
225  RandomAccessBaseFloatMatrixReader matrix_reader(matrix_rspecifier);
226  NnetExampleWriter example_writer(examples_wspecifier);
227  RandomAccessBaseFloatMatrixReader online_ivector_reader(
228  online_ivector_rspecifier);
229 
230  int32 num_err = 0;
231 
232  for (; !feat_reader.Done(); feat_reader.Next()) {
233  std::string key = feat_reader.Key();
234  const GeneralMatrix &feats = feat_reader.Value();
235  if (!matrix_reader.HasKey(key)) {
236  KALDI_WARN << "No target matrix for key " << key;
237  num_err++;
238  } else {
239  const Matrix<BaseFloat> &target_matrix = matrix_reader.Value(key);
240  const Matrix<BaseFloat> *online_ivector_feats = NULL;
241  if (!online_ivector_rspecifier.empty()) {
242  if (!online_ivector_reader.HasKey(key)) {
243  KALDI_WARN << "No iVectors for utterance " << key;
244  num_err++;
245  continue;
246  } else {
247  // this address will be valid until we call HasKey() or Value()
248  // again.
249  online_ivector_feats = &(online_ivector_reader.Value(key));
250  }
251  }
252 
253  if (online_ivector_feats != NULL &&
254  (abs(feats.NumRows() - (online_ivector_feats->NumRows() *
255  online_ivector_period)) > length_tolerance
256  || online_ivector_feats->NumRows() == 0)) {
257  KALDI_WARN << "Length difference between feats " << feats.NumRows()
258  << " and iVectors " << online_ivector_feats->NumRows()
259  << "exceeds tolerance " << length_tolerance;
260  num_err++;
261  continue;
262  }
263 
264  if (!ProcessFile(feats, online_ivector_feats, online_ivector_period,
265  target_matrix, key, compress, num_targets,
266  targets_length_tolerance,
267  &utt_splitter, &example_writer))
268  num_err++;
269  }
270  }
271  if (num_err > 0)
272  KALDI_WARN << num_err << " utterances had errors and could "
273  "not be processed.";
274  // utt_splitter prints stats in its destructor.
275  return utt_splitter.ExitStatus();
276  } catch(const std::exception &e) {
277  std::cerr << e.what() << '\n';
278  return -1;
279  }
280 }
NnetExample is the input data and corresponding label (or labels) for one or more frames of input...
Definition: nnet-example.h:111
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
This class is a wrapper that enables you to store a matrix in one of three forms: either as a Matrix<...
bool LengthsMatch(const std::string &utt, int32 utterance_length, int32 supervision_length, int32 length_tolerance=0) const
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
Base class which provides matrix operations not involving resizing or allocation. ...
Definition: kaldi-matrix.h:49
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
int main(int argc, char *argv[])
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
void CopyFromVec(const VectorBase< Real > &v)
Copy data from another vector (must match own size).
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
const T & Value(const std::string &key)
static bool ProcessFile(const discriminative::SplitDiscriminativeSupervisionOptions &config, const TransitionModel &tmodel, const MatrixBase< BaseFloat > &feats, const MatrixBase< BaseFloat > *ivector_feats, int32 ivector_period, const discriminative::DiscriminativeSupervision &supervision, const std::string &utt_id, bool compress, UtteranceSplitter *utt_splitter, NnetDiscriminativeExampleWriter *example_writer)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
void ExtractRowRangeWithPadding(const GeneralMatrix &in, int32 row_offset, int32 num_rows, GeneralMatrix *out)
This function extracts a row-range of a GeneralMatrix and writes as a GeneralMatrix containing the sa...
void Compress()
Compresses any (input) features that are not sparse.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
#define KALDI_WARN
Definition: kaldi-error.h:150
bool HasKey(const std::string &key)
const ExampleGenerationConfig & Config() const
int NumArgs() const
Number of positional parameters (c.f. argc-1).
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
MatrixIndexT NumRows() const
void GetChunksForUtterance(int32 utterance_length, std::vector< ChunkTimeInfo > *chunk_info)
struct ChunkTimeInfo is used by class UtteranceSplitter to output information about how we split an u...
std::vector< NnetIo > io
"io" contains the input and output.
Definition: nnet-example.h:116
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:501
int32 RandInt(int32 min_val, int32 max_val, struct RandomState *state)
Definition: kaldi-math.cc:95
void ComputeDerived()
This function decodes 'num_frames_str' into 'num_frames', and ensures that the members of 'num_frames...