nnet-chain-example.cc
Go to the documentation of this file.
1 // nnet3/nnet-chain-example.cc
2 
3 // Copyright 2015 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include <cmath>
23 
24 namespace kaldi {
25 namespace nnet3 {
26 
27 
// Serializes this object as:
//   <NnetChainSup> name indexes supervision <DW2> deriv_weights </NnetChainSup>
// CheckDim() is called first so that inconsistent members are caught before
// anything is written.
void NnetChainSupervision::Write(std::ostream &os, bool binary) const {
  CheckDim();
  WriteToken(os, binary, "<NnetChainSup>");
  WriteToken(os, binary, name);
  WriteIndexVector(os, binary, indexes);
  supervision.Write(os, binary);
  // deriv_weights are always written as a regular vector under "<DW2>"
  // (Read() additionally accepts the older "<DW>" char-compressed form).
  // An empty vector is still written, so the tag is always present.
  WriteToken(os, binary, "<DW2>");
  deriv_weights.Write(os, binary);
  WriteToken(os, binary, "</NnetChainSup>");
}
38 
40  return name == other.name && indexes == other.indexes &&
41  supervision == other.supervision &&
42  deriv_weights.ApproxEqual(other.deriv_weights);
43 }
44 
45 void NnetChainSupervision::Read(std::istream &is, bool binary) {
46  ExpectToken(is, binary, "<NnetChainSup>");
47  ReadToken(is, binary, &name);
48  ReadIndexVector(is, binary, &indexes);
49  supervision.Read(is, binary);
50  std::string token;
51  ReadToken(is, binary, &token);
52  // in the future this back-compatibility code can be reworked.
53  if (token != "</NnetChainSup>") {
54  KALDI_ASSERT(token == "<DW>" || token == "<DW2>");
55  if (token == "<DW>")
56  ReadVectorAsChar(is, binary, &deriv_weights);
57  else
58  deriv_weights.Read(is, binary);
59  ExpectToken(is, binary, "</NnetChainSup>");
60  }
61  CheckDim();
62 }
63 
64 
66  if (supervision.frames_per_sequence == -1) {
67  // this object has not been set up.
68  KALDI_ASSERT(indexes.empty());
69  return;
70  }
71  KALDI_ASSERT(indexes.size() == supervision.num_sequences *
72  supervision.frames_per_sequence && !indexes.empty() &&
73  supervision.frames_per_sequence > 1);
74  int32 first_frame = indexes[0].t,
75  frame_skip = indexes[supervision.num_sequences].t - first_frame,
76  num_sequences = supervision.num_sequences,
77  frames_per_sequence = supervision.frames_per_sequence;
78  int32 k = 0;
79  for (int32 i = 0; i < frames_per_sequence; i++) {
80  for (int32 j = 0; j < num_sequences; j++,k++) {
81  int32 n = j, t = i * frame_skip + first_frame, x = 0;
82  Index index(n, t, x);
83  KALDI_ASSERT(indexes[k] == index);
84  }
85  }
86  if (deriv_weights.Dim() != 0) {
87  KALDI_ASSERT(deriv_weights.Dim() == indexes.size());
88  KALDI_ASSERT(deriv_weights.Min() >= 0.0);
89  }
90 }
91 
93  name(other.name),
94  indexes(other.indexes),
95  supervision(other.supervision),
97 
99  name.swap(other->name);
100  indexes.swap(other->indexes);
101  supervision.Swap(&(other->supervision));
102  deriv_weights.Swap(&(other->deriv_weights));
103  if (RandInt(0, 5) == 0)
104  CheckDim();
105 }
106 
108  const std::string &name,
109  const chain::Supervision &supervision,
111  int32 first_frame,
112  int32 frame_skip):
113  name(name),
114  supervision(supervision),
115  deriv_weights(deriv_weights) {
116  // note: this will set the 'x' index to zero.
117  indexes.resize(supervision.num_sequences *
118  supervision.frames_per_sequence);
119  int32 k = 0, num_sequences = supervision.num_sequences,
120  frames_per_sequence = supervision.frames_per_sequence;
121  for (int32 i = 0; i < frames_per_sequence; i++) {
122  for (int32 j = 0; j < num_sequences; j++,k++) {
123  indexes[k].n = j;
124  indexes[k].t = i * frame_skip + first_frame;
125  }
126  }
127  KALDI_ASSERT(k == indexes.size());
128  CheckDim();
129 }
130 
131 
132 void NnetChainExample::Write(std::ostream &os, bool binary) const {
133  // Note: weight, label, input_frames and spk_info are members. This is a
134  // struct.
135  WriteToken(os, binary, "<Nnet3ChainEg>");
136  WriteToken(os, binary, "<NumInputs>");
137  int32 size = inputs.size();
138  WriteBasicType(os, binary, size);
139  KALDI_ASSERT(size > 0 && "Attempting to write NnetChainExample with no inputs");
140  if (!binary) os << '\n';
141  for (int32 i = 0; i < size; i++) {
142  inputs[i].Write(os, binary);
143  if (!binary) os << '\n';
144  }
145  WriteToken(os, binary, "<NumOutputs>");
146  size = outputs.size();
147  WriteBasicType(os, binary, size);
148  KALDI_ASSERT(size > 0 && "Attempting to write NnetChainExample with no outputs");
149  if (!binary) os << '\n';
150  for (int32 i = 0; i < size; i++) {
151  outputs[i].Write(os, binary);
152  if (!binary) os << '\n';
153  }
154  WriteToken(os, binary, "</Nnet3ChainEg>");
155 }
156 
157 void NnetChainExample::Read(std::istream &is, bool binary) {
158  ExpectToken(is, binary, "<Nnet3ChainEg>");
159  ExpectToken(is, binary, "<NumInputs>");
160  int32 size;
161  ReadBasicType(is, binary, &size);
162  if (size < 1 || size > 1000000)
163  KALDI_ERR << "Invalid size " << size;
164  inputs.resize(size);
165  for (int32 i = 0; i < size; i++)
166  inputs[i].Read(is, binary);
167  ExpectToken(is, binary, "<NumOutputs>");
168  ReadBasicType(is, binary, &size);
169  if (size < 1 || size > 1000000)
170  KALDI_ERR << "Invalid size " << size;
171  outputs.resize(size);
172  for (int32 i = 0; i < size; i++)
173  outputs[i].Read(is, binary);
174  ExpectToken(is, binary, "</Nnet3ChainEg>");
175 }
176 
178  inputs.swap(other->inputs);
179  outputs.swap(other->outputs);
180 }
181 
183  std::vector<NnetIo>::iterator iter = inputs.begin(), end = inputs.end();
184  // calling features.Compress() will do nothing if they are sparse or already
185  // compressed.
186  for (; iter != end; ++iter) iter->features.Compress();
187 }
188 
190  inputs(other.inputs), outputs(other.outputs) { }
191 
192 
193 // called from MergeChainExamplesInternal, this function merges the Supervision
194 // objects into one. Requires (and checks) that they all have the same name.
195 static void MergeSupervision(
196  const std::vector<const NnetChainSupervision*> &inputs,
197  NnetChainSupervision *output) {
198  int32 num_inputs = inputs.size(),
199  num_indexes = 0;
200  for (int32 n = 0; n < num_inputs; n++) {
201  KALDI_ASSERT(inputs[n]->name == inputs[0]->name);
202  num_indexes += inputs[n]->indexes.size();
203  }
204  output->name = inputs[0]->name;
205  std::vector<const chain::Supervision*> input_supervision;
206  input_supervision.reserve(inputs.size());
207  for (int32 n = 0; n < num_inputs; n++)
208  input_supervision.push_back(&(inputs[n]->supervision));
209  chain::Supervision output_supervision;
210  MergeSupervision(input_supervision,
211  &output_supervision);
212  output->supervision.Swap(&output_supervision);
213 
214  output->indexes.clear();
215  output->indexes.reserve(num_indexes);
216  for (int32 n = 0; n < num_inputs; n++) {
217  const std::vector<Index> &src_indexes = inputs[n]->indexes;
218  int32 cur_size = output->indexes.size();
219  output->indexes.insert(output->indexes.end(),
220  src_indexes.begin(), src_indexes.end());
221  std::vector<Index>::iterator iter = output->indexes.begin() + cur_size,
222  end = output->indexes.end();
223  // change the 'n' index to correspond to the index into 'input'.
224  // Each example gets a different 'n' value, starting from 0.
225  for (; iter != end; ++iter) {
226  KALDI_ASSERT(iter->n == 0 && "Merging already-merged chain egs");
227  iter->n = n;
228  }
229  }
230  KALDI_ASSERT(output->indexes.size() == num_indexes);
231  // OK, at this point the 'indexes' will be in the wrong order,
232  // because they should be first sorted by 't' and next by 'n'.
233  // 'sort' will fix this, due to the operator < on type Index.
234  std::sort(output->indexes.begin(), output->indexes.end());
235 
236  // merge the deriv_weights.
237  if (inputs[0]->deriv_weights.Dim() != 0) {
238  int32 frames_per_sequence = inputs[0]->deriv_weights.Dim();
239  output->deriv_weights.Resize(output->indexes.size(), kUndefined);
240  KALDI_ASSERT(output->deriv_weights.Dim() ==
241  frames_per_sequence * num_inputs);
242  for (int32 n = 0; n < num_inputs; n++) {
243  const Vector<BaseFloat> &src_deriv_weights = inputs[n]->deriv_weights;
244  KALDI_ASSERT(src_deriv_weights.Dim() == frames_per_sequence);
245  // the ordering of the deriv_weights corresponds to the ordering of the
246  // Indexes, where the time dimension has the greater stride.
247  for (int32 t = 0; t < frames_per_sequence; t++) {
248  output->deriv_weights(t * num_inputs + n) = src_deriv_weights(t);
249  }
250  }
251  }
252  output->CheckDim();
253 }
254 
255 
256 void MergeChainExamples(bool compress,
257  std::vector<NnetChainExample> *input,
258  NnetChainExample *output) {
259  int32 num_examples = input->size();
260  KALDI_ASSERT(num_examples > 0);
261  // we temporarily make the input-features in 'input' look like regular NnetExamples,
262  // so that we can recycle the MergeExamples() function.
263  std::vector<NnetExample> eg_inputs(num_examples);
264  for (int32 i = 0; i < num_examples; i++)
265  eg_inputs[i].io.swap((*input)[i].inputs);
266  NnetExample eg_output;
267  MergeExamples(eg_inputs, compress, &eg_output);
268  // swap the inputs back so that they are not really changed.
269  for (int32 i = 0; i < num_examples; i++)
270  eg_inputs[i].io.swap((*input)[i].inputs);
271  // write to 'output->inputs'
272  eg_output.io.swap(output->inputs);
273 
274  // Now deal with the chain-supervision 'outputs'. There will
275  // normally be just one of these, with name "output", but we
276  // handle the more general case.
277  int32 num_output_names = (*input)[0].outputs.size();
278  output->outputs.resize(num_output_names);
279  for (int32 i = 0; i < num_output_names; i++) {
280  std::vector<const NnetChainSupervision*> to_merge(num_examples);
281  for (int32 j = 0; j < num_examples; j++) {
282  KALDI_ASSERT((*input)[j].outputs.size() == num_output_names);
283  to_merge[j] = &((*input)[j].outputs[i]);
284  }
285  MergeSupervision(to_merge,
286  &(output->outputs[i]));
287  }
288 }
289 
291  const NnetChainExample &eg,
292  bool need_model_derivative,
293  bool store_component_stats,
294  bool use_xent_regularization,
295  bool use_xent_derivative,
296  ComputationRequest *request) {
297  request->inputs.clear();
298  request->inputs.reserve(eg.inputs.size());
299  request->outputs.clear();
300  request->outputs.reserve(eg.outputs.size() * 2);
301  request->need_model_derivative = need_model_derivative;
302  request->store_component_stats = store_component_stats;
303  for (size_t i = 0; i < eg.inputs.size(); i++) {
304  const NnetIo &io = eg.inputs[i];
305  const std::string &name = io.name;
306  int32 node_index = nnet.GetNodeIndex(name);
307  if (node_index == -1 ||
308  !nnet.IsInputNode(node_index))
309  KALDI_ERR << "Nnet example has input named '" << name
310  << "', but no such input node is in the network.";
311 
312  request->inputs.resize(request->inputs.size() + 1);
313  IoSpecification &io_spec = request->inputs.back();
314  io_spec.name = name;
315  io_spec.indexes = io.indexes;
316  io_spec.has_deriv = false;
317  }
318  for (size_t i = 0; i < eg.outputs.size(); i++) {
319  // there will normally be exactly one output , named "output"
320  const NnetChainSupervision &sup = eg.outputs[i];
321  const std::string &name = sup.name;
322  int32 node_index = nnet.GetNodeIndex(name);
323  if (node_index == -1 &&
324  !nnet.IsOutputNode(node_index))
325  KALDI_ERR << "Nnet example has output named '" << name
326  << "', but no such output node is in the network.";
327  request->outputs.resize(request->outputs.size() + 1);
328  IoSpecification &io_spec = request->outputs.back();
329  io_spec.name = name;
330  io_spec.indexes = sup.indexes;
331  io_spec.has_deriv = need_model_derivative;
332 
333  if (use_xent_regularization) {
334  size_t cur_size = request->outputs.size();
335  request->outputs.resize(cur_size + 1);
336  IoSpecification &io_spec = request->outputs[cur_size - 1],
337  &io_spec_xent = request->outputs[cur_size];
338  // the IoSpecification for the -xent output is the same
339  // as for the regular output, except for its name which has
340  // the -xent suffix (and the has_deriv member may differ).
341  io_spec_xent = io_spec;
342  io_spec_xent.name = name + "-xent";
343  io_spec_xent.has_deriv = use_xent_derivative;
344  }
345  }
346  // check to see if something went wrong.
347  if (request->inputs.empty())
348  KALDI_ERR << "No inputs in computation request.";
349  if (request->outputs.empty())
350  KALDI_ERR << "No outputs in computation request.";
351 }
352 
353 void ShiftChainExampleTimes(int32 frame_shift,
354  const std::vector<std::string> &exclude_names,
355  NnetChainExample *eg) {
356  std::vector<NnetIo>::iterator input_iter = eg->inputs.begin(),
357  input_end = eg->inputs.end();
358  for (; input_iter != input_end; ++input_iter) {
359  bool must_exclude = false;
360  std::vector<std::string>::const_iterator exclude_iter = exclude_names.begin(),
361  exclude_end = exclude_names.end();
362  for (; exclude_iter != exclude_end; ++exclude_iter)
363  if (input_iter->name == *exclude_iter)
364  must_exclude = true;
365  if (!must_exclude) {
366  std::vector<Index>::iterator indexes_iter = input_iter->indexes.begin(),
367  indexes_end = input_iter->indexes.end();
368  for (; indexes_iter != indexes_end; ++indexes_iter)
369  indexes_iter->t += frame_shift;
370  }
371  }
372  // note: we'll normally choose a small enough shift that the output-data
373  // shift will be zero after dividing by frame_subsampling_factor
374  // (e.g. frame_subsampling_factor == 3 and shift = 0 or 1.
375  std::vector<NnetChainSupervision>::iterator
376  sup_iter = eg->outputs.begin(),
377  sup_end = eg->outputs.end();
378  for (; sup_iter != sup_end; ++sup_iter) {
379  std::vector<Index> &indexes = sup_iter->indexes;
380  KALDI_ASSERT(indexes.size() >= 2 && indexes[0].n == indexes[1].n &&
381  indexes[0].x == indexes[1].x);
382  int32 frame_subsampling_factor = indexes[1].t - indexes[0].t;
383  KALDI_ASSERT(frame_subsampling_factor > 0);
384 
385  // We need to shift by a multiple of frame_subsampling_factor.
386  // Round to the closest multiple.
387  int32 supervision_frame_shift =
388  frame_subsampling_factor *
389  std::floor(0.5 + (frame_shift * 1.0 / frame_subsampling_factor));
390  if (supervision_frame_shift == 0)
391  continue;
392  std::vector<Index>::iterator indexes_iter = indexes.begin(),
393  indexes_end = indexes.end();
394  for (; indexes_iter != indexes_end; ++indexes_iter)
395  indexes_iter->t += supervision_frame_shift;
396  }
397 }
398 
399 
401  const NnetChainExample &eg) const noexcept {
402  // these numbers were chosen at random from a list of primes.
403  NnetIoStructureHasher io_hasher;
404  size_t size = eg.inputs.size(), ans = size * 35099;
405  for (size_t i = 0; i < size; i++)
406  ans = ans * 19157 + io_hasher(eg.inputs[i]);
407  for (size_t i = 0; i < eg.outputs.size(); i++) {
408  const NnetChainSupervision &sup = eg.outputs[i];
409  StringHasher string_hasher;
410  IndexVectorHasher indexes_hasher;
411  ans = ans * 17957 +
412  string_hasher(sup.name) + indexes_hasher(sup.indexes);
413  }
414  return ans;
415 }
416 
418  const NnetChainExample &a,
419  const NnetChainExample &b) const {
420  NnetIoStructureCompare io_compare;
421  if (a.inputs.size() != b.inputs.size() ||
422  a.outputs.size() != b.outputs.size())
423  return false;
424  size_t size = a.inputs.size();
425  for (size_t i = 0; i < size; i++)
426  if (!io_compare(a.inputs[i], b.inputs[i]))
427  return false;
428  size = a.outputs.size();
429  for (size_t i = 0; i < size; i++)
430  if (a.outputs[i].name != b.outputs[i].name ||
431  a.outputs[i].indexes != b.outputs[i].indexes)
432  return false;
433  return true;
434 }
435 
436 
438  int32 ans = 0;
439  for (size_t i = 0; i < a.inputs.size(); i++) {
440  int32 s = a.inputs[i].indexes.size();
441  if (s > ans)
442  ans = s;
443  }
444  for (size_t i = 0; i < a.outputs.size(); i++) {
445  int32 s = a.outputs[i].indexes.size();
446  if (s > ans)
447  ans = s;
448  }
449  return ans;
450 }
451 
452 
454  NnetChainExampleWriter *writer):
455  finished_(false), num_egs_written_(0),
456  config_(config), writer_(writer) { }
457 
458 
461  // If an eg with the same structure as 'eg' is already a key in the
462  // map, it won't be replaced, but if it's new it will be made
463  // the key. Also we remove the key before making the vector empty.
464  // This way we ensure that the eg in the key is always the first
465  // element of the vector.
466  std::vector<NnetChainExample*> &vec = eg_to_egs_[eg];
467  vec.push_back(eg);
468  int32 eg_size = GetNnetChainExampleSize(*eg),
469  num_available = vec.size();
470  bool input_ended = false;
471  int32 minibatch_size = config_.MinibatchSize(eg_size, num_available,
472  input_ended);
473  if (minibatch_size != 0) { // we need to write out a merged eg.
474  KALDI_ASSERT(minibatch_size == num_available);
475 
476  std::vector<NnetChainExample*> vec_copy(vec);
477  eg_to_egs_.erase(eg);
478 
479  // MergeChainExamples() expects a vector of NnetChainExample, not of pointers,
480  // so use swap to create that without doing any real work.
481  std::vector<NnetChainExample> egs_to_merge(minibatch_size);
482  for (int32 i = 0; i < minibatch_size; i++) {
483  egs_to_merge[i].Swap(vec_copy[i]);
484  delete vec_copy[i]; // we owned those pointers.
485  }
486  WriteMinibatch(&egs_to_merge);
487  }
488 }
489 
491  std::vector<NnetChainExample> *egs) {
492  KALDI_ASSERT(!egs->empty());
493  int32 eg_size = GetNnetChainExampleSize((*egs)[0]);
495  size_t structure_hash = eg_hasher((*egs)[0]);
496  int32 minibatch_size = egs->size();
497  stats_.WroteExample(eg_size, structure_hash, minibatch_size);
498  NnetChainExample merged_eg;
499  MergeChainExamples(config_.compress, egs, &merged_eg);
500  std::ostringstream key;
501  key << "merged-" << (num_egs_written_++) << "-" << minibatch_size;
502  writer_->Write(key.str(), merged_eg);
503 }
504 
506  if (finished_) return; // already finished.
507  finished_ = true;
508 
509  // we'll convert the map eg_to_egs_ to a vector of vectors to avoid
510  // iterator invalidation problems.
511  std::vector<std::vector<NnetChainExample*> > all_egs;
512  all_egs.reserve(eg_to_egs_.size());
513 
514  MapType::iterator iter = eg_to_egs_.begin(), end = eg_to_egs_.end();
515  for (; iter != end; ++iter)
516  all_egs.push_back(iter->second);
517  eg_to_egs_.clear();
518 
519  for (size_t i = 0; i < all_egs.size(); i++) {
520  int32 minibatch_size;
521  std::vector<NnetChainExample*> &vec = all_egs[i];
522  KALDI_ASSERT(!vec.empty());
523  int32 eg_size = GetNnetChainExampleSize(*(vec[0]));
524  bool input_ended = true;
525  while (!vec.empty() &&
526  (minibatch_size = config_.MinibatchSize(eg_size, vec.size(),
527  input_ended)) != 0) {
528  // MergeChainExamples() expects a vector of
529  // NnetChainExample, not of pointers, so use swap to create that
530  // without doing any real work.
531  std::vector<NnetChainExample> egs_to_merge(minibatch_size);
532  for (int32 i = 0; i < minibatch_size; i++) {
533  egs_to_merge[i].Swap(vec[i]);
534  delete vec[i]; // we owned those pointers.
535  }
536  vec.erase(vec.begin(), vec.begin() + minibatch_size);
537  WriteMinibatch(&egs_to_merge);
538  }
539  if (!vec.empty()) {
540  int32 eg_size = GetNnetChainExampleSize(*(vec[0]));
542  size_t structure_hash = eg_hasher(*(vec[0]));
543  int32 num_discarded = vec.size();
544  stats_.DiscardedExamples(eg_size, structure_hash, num_discarded);
545  for (int32 i = 0; i < num_discarded; i++)
546  delete vec[i];
547  vec.clear();
548  }
549  }
550  stats_.PrintStats();
551 }
552 
553 
554 
555 } // namespace nnet3
556 } // namespace kaldi
NnetExample is the input data and corresponding label (or labels) for one or more frames of input...
Definition: nnet-example.h:111
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void WriteIndexVector(std::ostream &os, bool binary, const std::vector< Index > &vec)
Definition: nnet-common.cc:126
Vector< BaseFloat > deriv_weights
This is a vector of per-frame weights, required to be between 0 and 1, that is applied to the derivat...
void DiscardedExamples(int32 example_size, size_t structure_hash, int32 num_discarded)
Users call this function to inform this class that after processing all the data, for examples of ori...
bool store_component_stats
you should set store_component_stats to true if you need the average-activation and average-derivative...
bool need_model_derivative
if need_model_derivative is true, then we'll be doing either model training or model-derivative compu...
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
chain::Supervision supervision
The supervision object, containing the FST.
std::vector< NnetIo > inputs
'inputs' contains the input to the network — normally it has just one element, called "input"...
void MergeChainExamples(bool compress, std::vector< NnetChainExample > *input, NnetChainExample *output)
This function merges a list of NnetChainExample objects into a single one– intended to be used when ...
void ShiftChainExampleTimes(int32 frame_shift, const std::vector< std::string > &exclude_names, NnetChainExample *eg)
Shifts the time-index t of everything in the input of "eg" by adding "frame_shift" to all "t" values — b...
bool IsInputNode(int32 node) const
Returns true if this is an input node, meaning that it is of type kInput.
Definition: nnet-nnet.cc:120
static void MergeSupervision(const std::vector< const NnetChainSupervision *> &inputs, NnetChainSupervision *output)
int32 MinibatchSize(int32 size_of_eg, int32 num_available_egs, bool input_ended) const
This function tells you what minibatch size should be used for this eg.
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
void ReadToken(std::istream &is, bool binary, std::string *str)
ReadToken gets the next token and puts it in str (exception on failure).
Definition: io-funcs.cc:154
std::vector< IoSpecification > inputs
std::vector< Index > indexes
"indexes" is a vector the same length as features.NumRows(), explaining the meaning of each row of th...
Definition: nnet-example.h:42
void Write(std::ostream &os, bool binary) const
void Swap(NnetChainSupervision *other)
A hashing function object for strings.
Definition: stl-utils.h:248
std::string name
the name of the output in the neural net; in simple setups it will just be "output".
void Write(const std::string &key, const T &value) const
struct Index is intended to represent the various indexes by which we number the rows of the matrices...
Definition: nnet-common.h:44
int32 GetNnetChainExampleSize(const NnetChainExample &a)
bool operator()(const NnetChainExample &a, const NnetChainExample &b) const
void PrintStats() const
Calling this will cause a log message with information about the examples to be printed.
This hashing object hashes just the structural aspects of the NnetExample without looking at the valu...
bool IsOutputNode(int32 node) const
Returns true if this is an output node, meaning that it is of type kDescriptor and is not directly fo...
Definition: nnet-nnet.cc:112
static void ExpectToken(const std::string &token, const std::string &what_we_are_parsing, const std::string **next_token)
std::vector< NnetChainSupervision > outputs
'outputs' contains the chain output supervision.
ChainExampleMerger(const ExampleMergingConfig &config, NnetChainExampleWriter *writer)
struct rnnlm::@11::@12 n
void Read(std::istream &is, bool binary)
NnetChainExample is like NnetExample, but specialized for lattice-free (chain) training.
void AcceptExample(NnetChainExample *a)
size_t operator()(const NnetChainExample &eg) const noexcept
void Swap(NnetChainExample *other)
#define KALDI_ERR
Definition: kaldi-error.h:147
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
void ReadIndexVector(std::istream &is, bool binary, std::vector< Index > *vec)
Definition: nnet-common.cc:143
void Read(std::istream &is, bool binary)
void WroteExample(int32 example_size, size_t structure_hash, int32 minibatch_size)
Users call this function to inform this class that one minibatch has been written aggregating 'miniba...
const ExampleMergingConfig & config_
NnetChainExampleWriter * writer_
bool operator==(const NnetChainSupervision &other) const
std::vector< Index > indexes
void Write(std::ostream &os, bool binary) const
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
std::vector< IoSpecification > outputs
This comparison object compares just the structural aspects of the NnetIo object (name, indexes, feature dimension) without looking at the value of features.
Definition: nnet-example.h:101
std::vector< Index > indexes
The indexes that the output corresponds to.
This hashing object hashes just the structural aspects of the NnetIo object (name, indexes, feature dimension) without looking at the value of features.
Definition: nnet-example.h:94
void ReadVectorAsChar(std::istream &is, bool binary, Vector< BaseFloat > *vec)
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34
std::string name
the name of the input in the neural net; in simple setups it will just be "input".
Definition: nnet-example.h:36
void WriteMinibatch(std::vector< NnetChainExample > *egs)
int32 GetNodeIndex(const std::string &node_name) const
returns index associated with this node name, or -1 if no such index.
Definition: nnet-nnet.cc:466
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
void GetChainComputationRequest(const Nnet &nnet, const NnetChainExample &eg, bool need_model_derivative, bool store_component_stats, bool use_xent_regularization, bool use_xent_derivative, ComputationRequest *request)
This function takes a NnetChainExample and produces a ComputationRequest.
std::vector< NnetIo > io
"io" contains the input and output.
Definition: nnet-example.h:116
int32 RandInt(int32 min_val, int32 max_val, struct RandomState *state)
Definition: kaldi-math.cc:95
void MergeExamples(const std::vector< NnetExample > &src, bool compress, NnetExample *merged_eg)
Merge a set of input examples into a single example (typically the size of "src" will be the minibatc...