nnet-discriminative-example.cc
// nnet3/nnet-discriminative-example.cc

// Copyright 2015 Johns Hopkins University (author: Daniel Povey)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include <cmath>
#include "nnet3/nnet-discriminative-example.h"
#include "nnet3/nnet-example-utils.h"

namespace kaldi {
namespace nnet3 {
using std::string;

void NnetDiscriminativeSupervision::Write(std::ostream &os, bool binary) const {
  CheckDim();
  WriteToken(os, binary, "<NnetDiscriminativeSup>");
  WriteToken(os, binary, name);
  WriteIndexVector(os, binary, indexes);
  supervision.Write(os, binary);
  WriteToken(os, binary, "<DW>");  // for DerivWeights.  Want to save space.
  WriteVectorAsChar(os, binary, deriv_weights);
  WriteToken(os, binary, "</NnetDiscriminativeSup>");
}

bool NnetDiscriminativeSupervision::operator == (
    const NnetDiscriminativeSupervision &other) const {
  return name == other.name && indexes == other.indexes &&
      supervision == other.supervision &&
      deriv_weights.ApproxEqual(other.deriv_weights);
}

void NnetDiscriminativeSupervision::Read(std::istream &is, bool binary) {
  ExpectToken(is, binary, "<NnetDiscriminativeSup>");
  ReadToken(is, binary, &name);
  ReadIndexVector(is, binary, &indexes);
  supervision.Read(is, binary);
  ExpectToken(is, binary, "<DW>");
  ReadVectorAsChar(is, binary, &deriv_weights);
  ExpectToken(is, binary, "</NnetDiscriminativeSup>");
  CheckDim();
}
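
// Rough illustration of how the Write()/Read() pair above is meant to be used
// (a sketch, not part of any particular binary; 'sup' is assumed to be an
// already-constructed NnetDiscriminativeSupervision, and <sstream> is assumed
// to be available):
//
//   std::ostringstream oss;
//   sup.Write(oss, true);                 // binary mode
//   NnetDiscriminativeSupervision sup2;
//   std::istringstream iss(oss.str());
//   sup2.Read(iss, true);
//   KALDI_ASSERT(sup2 == sup);            // operator == defined above.
//
// Note that deriv_weights are written in a compressed one-byte-per-element
// format (WriteVectorAsChar), so the comparison relies on the ApproxEqual()
// call inside operator ==.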


void NnetDiscriminativeSupervision::CheckDim() const {
  if (supervision.frames_per_sequence == -1) {
    // this object has not been set up.
    KALDI_ASSERT(indexes.empty());
    return;
  }
  KALDI_ASSERT(indexes.size() == supervision.num_sequences *
               supervision.frames_per_sequence && !indexes.empty() &&
               supervision.frames_per_sequence > 1);
  int32 first_frame = indexes[0].t,
      frame_skip = indexes[supervision.num_sequences].t - first_frame,
      num_sequences = supervision.num_sequences,
      frames_per_sequence = supervision.frames_per_sequence;
  int32 k = 0;
  for (int32 i = 0; i < frames_per_sequence; i++) {
    for (int32 j = 0; j < num_sequences; j++, k++) {
      int32 n = j, t = i * frame_skip + first_frame, x = 0;
      Index index(n, t, x);
      KALDI_ASSERT(indexes[k] == index);
    }
  }
  if (deriv_weights.Dim() != 0) {
    KALDI_ASSERT(deriv_weights.Dim() == indexes.size());
    KALDI_ASSERT(deriv_weights.Min() >= 0.0 &&
                 deriv_weights.Max() <= 1.0);
  }
}

NnetDiscriminativeSupervision::NnetDiscriminativeSupervision(
    const NnetDiscriminativeSupervision &other):
    name(other.name),
    indexes(other.indexes),
    supervision(other.supervision),
    deriv_weights(other.deriv_weights) { CheckDim(); }

NnetDiscriminativeSupervision::NnetDiscriminativeSupervision(
    const std::string &name,
    const discriminative::DiscriminativeSupervision &supervision,
    const VectorBase<BaseFloat> &deriv_weights,
    int32 first_frame,
    int32 frame_skip):
    name(name),
    supervision(supervision),
    deriv_weights(deriv_weights) {
  // note: this will set the 'x' index to zero.
  indexes.resize(supervision.num_sequences *
                 supervision.frames_per_sequence);
  int32 k = 0, num_sequences = supervision.num_sequences,
      frames_per_sequence = supervision.frames_per_sequence;
  for (int32 i = 0; i < frames_per_sequence; i++) {
    for (int32 j = 0; j < num_sequences; j++, k++) {
      indexes[k].n = j;
      indexes[k].t = i * frame_skip + first_frame;
    }
  }
  KALDI_ASSERT(k == indexes.size());
  CheckDim();
}
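
// For concreteness: with supervision.num_sequences == 2,
// supervision.frames_per_sequence == 3, first_frame == 0 and frame_skip == 3,
// the loop above produces indexes (n, t) in this order ('n' varies fastest,
// 't' has the larger stride; all 'x' values are zero):
//
//   (0, 0), (1, 0), (0, 3), (1, 3), (0, 6), (1, 6)
//
// which is exactly the layout that CheckDim() verifies.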

void NnetDiscriminativeSupervision::Swap(NnetDiscriminativeSupervision *other) {
  name.swap(other->name);
  indexes.swap(other->indexes);
  supervision.Swap(&(other->supervision));
  deriv_weights.Swap(&(other->deriv_weights));
  if (RandInt(0, 5) == 0)
    CheckDim();
}


void NnetDiscriminativeExample::Write(std::ostream &os, bool binary) const {
  // Note: this is a struct; the members written here are 'inputs' and
  // 'outputs'.
  WriteToken(os, binary, "<Nnet3DiscriminativeEg>");
  WriteToken(os, binary, "<NumInputs>");
  int32 size = inputs.size();
  WriteBasicType(os, binary, size);
  KALDI_ASSERT(size > 0 &&
               "Attempting to write NnetDiscriminativeExample with no inputs");
  if (!binary) os << '\n';
  for (int32 i = 0; i < size; i++) {
    inputs[i].Write(os, binary);
    if (!binary) os << '\n';
  }
  WriteToken(os, binary, "<NumOutputs>");
  size = outputs.size();
  WriteBasicType(os, binary, size);
  KALDI_ASSERT(size > 0 &&
               "Attempting to write NnetDiscriminativeExample with no outputs");
  if (!binary) os << '\n';
  for (int32 i = 0; i < size; i++) {
    outputs[i].Write(os, binary);
    if (!binary) os << '\n';
  }
  WriteToken(os, binary, "</Nnet3DiscriminativeEg>");
}

void NnetDiscriminativeExample::Read(std::istream &is, bool binary) {
  ExpectToken(is, binary, "<Nnet3DiscriminativeEg>");
  ExpectToken(is, binary, "<NumInputs>");
  int32 size;
  ReadBasicType(is, binary, &size);
  if (size < 1 || size > 1000000)
    KALDI_ERR << "Invalid size " << size;
  inputs.resize(size);
  for (int32 i = 0; i < size; i++)
    inputs[i].Read(is, binary);
  ExpectToken(is, binary, "<NumOutputs>");
  ReadBasicType(is, binary, &size);
  if (size < 1 || size > 1000000)
    KALDI_ERR << "Invalid size " << size;
  outputs.resize(size);
  for (int32 i = 0; i < size; i++)
    outputs[i].Read(is, binary);
  ExpectToken(is, binary, "</Nnet3DiscriminativeEg>");
}

void NnetDiscriminativeExample::Swap(NnetDiscriminativeExample *other) {
  inputs.swap(other->inputs);
  outputs.swap(other->outputs);
}

void NnetDiscriminativeExample::Compress() {
  std::vector<NnetIo>::iterator iter = inputs.begin(), end = inputs.end();
  // calling features.Compress() will do nothing if they are sparse or already
  // compressed.
  for (; iter != end; ++iter) iter->features.Compress();
}

NnetDiscriminativeExample::NnetDiscriminativeExample(
    const NnetDiscriminativeExample &other):
    inputs(other.inputs), outputs(other.outputs) { }

// Called from MergeDiscriminativeExamples() below; merges a list of
// NnetDiscriminativeSupervision objects (which must all share the same name)
// into a single one.
static void MergeSupervision(
    const std::vector<const NnetDiscriminativeSupervision*> &inputs,
    NnetDiscriminativeSupervision *output) {
  int32 num_inputs = inputs.size(),
      num_indexes = 0;
  for (int32 n = 0; n < num_inputs; n++) {
    KALDI_ASSERT(inputs[n]->name == inputs[0]->name);
    num_indexes += inputs[n]->indexes.size();
  }
  output->name = inputs[0]->name;
  std::vector<const discriminative::DiscriminativeSupervision*> input_supervision;
  input_supervision.reserve(inputs.size());
  for (int32 n = 0; n < num_inputs; n++)
    input_supervision.push_back(&(inputs[n]->supervision));
  discriminative::DiscriminativeSupervision output_supervision;
  discriminative::MergeSupervision(input_supervision,
                                   &output_supervision);
  output->supervision.Swap(&(output_supervision));

  output->indexes.clear();
  output->indexes.reserve(num_indexes);
  for (int32 n = 0; n < num_inputs; n++) {
    const std::vector<Index> &src_indexes = inputs[n]->indexes;
    int32 cur_size = output->indexes.size();
    output->indexes.insert(output->indexes.end(),
                           src_indexes.begin(), src_indexes.end());
    std::vector<Index>::iterator iter = output->indexes.begin() + cur_size,
        end = output->indexes.end();
    // change the 'n' index to correspond to the index into 'inputs'.
    // Each example gets a different 'n' value, starting from 0.
    for (; iter != end; ++iter) {
      KALDI_ASSERT(iter->n == 0 && "Merging already-merged discriminative egs");
      iter->n = n;
    }
  }
  KALDI_ASSERT(output->indexes.size() == num_indexes);
  // At this point the 'indexes' are in the wrong order, because they were
  // appended input-by-input; they should be sorted first by 't' and then by
  // 'n'.  std::sort fixes this, thanks to the operator < on type Index.
  std::sort(output->indexes.begin(), output->indexes.end());

  // merge the deriv_weights.
  if (inputs[0]->deriv_weights.Dim() != 0) {
    int32 frames_per_sequence = inputs[0]->deriv_weights.Dim();
    output->deriv_weights.Resize(output->indexes.size(), kUndefined);
    KALDI_ASSERT(output->deriv_weights.Dim() ==
                 frames_per_sequence * num_inputs);
    for (int32 n = 0; n < num_inputs; n++) {
      const Vector<BaseFloat> &src_deriv_weights = inputs[n]->deriv_weights;
      KALDI_ASSERT(src_deriv_weights.Dim() == frames_per_sequence);
      // the ordering of the deriv_weights corresponds to the ordering of the
      // Indexes, where the time dimension has the greater stride.
      for (int32 t = 0; t < frames_per_sequence; t++) {
        output->deriv_weights(t * num_inputs + n) = src_deriv_weights(t);
      }
    }
  }
  output->CheckDim();
}
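
// To make the merging above concrete: if we merge 3 single-sequence inputs,
// each with frames at t = 0, 3, 6, the concatenated indexes (n, t) after the
// 'n' reassignment are
//   (0,0) (0,3) (0,6)  (1,0) (1,3) (1,6)  (2,0) (2,3) (2,6)
// and after std::sort they become
//   (0,0) (1,0) (2,0)  (0,3) (1,3) (2,3)  (0,6) (1,6) (2,6),
// i.e. sorted by 't' first and then by 'n'.  The deriv_weights are written as
// output->deriv_weights(t_index * num_inputs + n), so they follow this same
// interleaved ordering.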


void MergeDiscriminativeExamples(
    bool compress,
    std::vector<NnetDiscriminativeExample> *input,
    NnetDiscriminativeExample *output) {
  int32 num_examples = input->size();
  KALDI_ASSERT(num_examples > 0);
  // we temporarily make the input-features in 'input' look like regular
  // NnetExamples, so that we can recycle the MergeExamples() function.
  std::vector<NnetExample> eg_inputs(num_examples);
  for (int32 i = 0; i < num_examples; i++)
    eg_inputs[i].io.swap((*input)[i].inputs);
  NnetExample eg_output;
  MergeExamples(eg_inputs, compress, &eg_output);
  // swap the inputs back so that they are not really changed.
  for (int32 i = 0; i < num_examples; i++)
    eg_inputs[i].io.swap((*input)[i].inputs);
  // write to 'output->inputs'
  eg_output.io.swap(output->inputs);

  // Now deal with the discriminative-supervision 'outputs'.  There will
  // normally be just one of these, with name "output", but we handle the
  // more general case.
  int32 num_output_names = (*input)[0].outputs.size();
  output->outputs.resize(num_output_names);
  for (int32 i = 0; i < num_output_names; i++) {
    std::vector<const NnetDiscriminativeSupervision*> to_merge(num_examples);
    for (int32 j = 0; j < num_examples; j++) {
      KALDI_ASSERT((*input)[j].outputs.size() == num_output_names);
      to_merge[j] = &((*input)[j].outputs[i]);
    }
    MergeSupervision(to_merge,
                     &(output->outputs[i]));
  }
}
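
// Minimal sketch of calling MergeDiscriminativeExamples() directly (the
// variable names here are hypothetical; normally DiscriminativeExampleMerger
// below does this for you):
//
//   std::vector<NnetDiscriminativeExample> egs;  // one minibatch worth of egs
//   // ... fill 'egs' ...
//   NnetDiscriminativeExample merged;
//   MergeDiscriminativeExamples(true /* compress */, &egs, &merged);
//
// 'egs' is passed by pointer because its contents are temporarily swapped
// into NnetExample objects while the inputs are merged.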


void GetDiscriminativeComputationRequest(const Nnet &nnet,
                                         const NnetDiscriminativeExample &eg,
                                         bool need_model_derivative,
                                         bool store_component_stats,
                                         bool use_xent_regularization,
                                         bool use_xent_derivative,
                                         ComputationRequest *request) {
  request->inputs.clear();
  request->inputs.reserve(eg.inputs.size());
  request->outputs.clear();
  request->outputs.reserve(eg.outputs.size());
  request->need_model_derivative = need_model_derivative;
  request->store_component_stats = store_component_stats;
  for (size_t i = 0; i < eg.inputs.size(); i++) {
    const NnetIo &io = eg.inputs[i];
    const std::string &name = io.name;
    int32 node_index = nnet.GetNodeIndex(name);
    if (node_index == -1 ||
        !nnet.IsInputNode(node_index))
      KALDI_ERR << "Nnet example has input named '" << name
                << "', but no such input node is in the network.";

    request->inputs.resize(request->inputs.size() + 1);
    IoSpecification &io_spec = request->inputs.back();
    io_spec.name = name;
    io_spec.indexes = io.indexes;
    io_spec.has_deriv = false;
  }
  for (size_t i = 0; i < eg.outputs.size(); i++) {
    // there will normally be exactly one output, named "output".
    const NnetDiscriminativeSupervision &sup = eg.outputs[i];
    const std::string &name = sup.name;
    int32 node_index = nnet.GetNodeIndex(name);
    if (node_index == -1 ||
        !nnet.IsOutputNode(node_index))
      KALDI_ERR << "Nnet example has output named '" << name
                << "', but no such output node is in the network.";
    request->outputs.resize(request->outputs.size() + 1);
    IoSpecification &io_spec = request->outputs.back();
    io_spec.name = name;
    io_spec.indexes = sup.indexes;
    io_spec.has_deriv = need_model_derivative;

    if (use_xent_regularization) {
      size_t cur_size = request->outputs.size();
      request->outputs.resize(cur_size + 1);
      IoSpecification &io_spec = request->outputs[cur_size - 1],
          &io_spec_xent = request->outputs[cur_size];
      // the IoSpecification for the -xent output is the same as for the
      // regular output, except for its name, which has the -xent suffix
      // (and the has_deriv member may differ).
      io_spec_xent = io_spec;
      io_spec_xent.name = name + "-xent";
      io_spec_xent.has_deriv = use_xent_derivative;
    }
  }
  // check to see if something went wrong.
  if (request->inputs.empty())
    KALDI_ERR << "No inputs in computation request.";
  if (request->outputs.empty())
    KALDI_ERR << "No outputs in computation request.";
}
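
// Sketch of how the function above is typically called during training (a
// simplified illustration, not a complete training loop; 'nnet', 'eg' and
// 'xent_regularize' are assumed to already exist):
//
//   ComputationRequest request;
//   GetDiscriminativeComputationRequest(nnet, eg,
//                                       true,   // need_model_derivative
//                                       true,   // store_component_stats
//                                       (xent_regularize != 0.0),
//                                       true,   // use_xent_derivative
//                                       &request);
//   // 'request' would then be compiled into a NnetComputation and executed
//   // by the nnet3 computation framework.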

void ShiftDiscriminativeExampleTimes(int32 frame_shift,
                                     const std::vector<std::string> &exclude_names,
                                     NnetDiscriminativeExample *eg) {
  std::vector<NnetIo>::iterator input_iter = eg->inputs.begin(),
      input_end = eg->inputs.end();
  for (; input_iter != input_end; ++input_iter) {
    bool must_exclude = false;
    std::vector<string>::const_iterator exclude_iter = exclude_names.begin(),
        exclude_end = exclude_names.end();
    for (; exclude_iter != exclude_end; ++exclude_iter)
      if (input_iter->name == *exclude_iter)
        must_exclude = true;
    if (!must_exclude) {
      std::vector<Index>::iterator indexes_iter = input_iter->indexes.begin(),
          indexes_end = input_iter->indexes.end();
      for (; indexes_iter != indexes_end; ++indexes_iter)
        indexes_iter->t += frame_shift;
    }
  }
  // note: we'll normally choose a small enough shift that the output-data
  // shift will be zero after dividing by frame_subsampling_factor
  // (e.g. frame_subsampling_factor == 3 and shift = 0 or 1).
  std::vector<NnetDiscriminativeSupervision>::iterator
      sup_iter = eg->outputs.begin(),
      sup_end = eg->outputs.end();
  for (; sup_iter != sup_end; ++sup_iter) {
    std::vector<Index> &indexes = sup_iter->indexes;
    KALDI_ASSERT(indexes.size() >= 2 && indexes[0].n == indexes[1].n &&
                 indexes[0].x == indexes[1].x);
    int32 frame_subsampling_factor = indexes[1].t - indexes[0].t;
    KALDI_ASSERT(frame_subsampling_factor > 0);

    // We need to shift by a multiple of frame_subsampling_factor.
    // Round to the closest multiple.
    int32 supervision_frame_shift =
        frame_subsampling_factor *
        std::floor(0.5 + (frame_shift * 1.0 / frame_subsampling_factor));
    if (supervision_frame_shift == 0)
      continue;
    std::vector<Index>::iterator indexes_iter = indexes.begin(),
        indexes_end = indexes.end();
    for (; indexes_iter != indexes_end; ++indexes_iter)
      indexes_iter->t += supervision_frame_shift;
  }
}
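
// Worked example of the rounding above, with frame_subsampling_factor == 3:
//   frame_shift == 1:  3 * floor(0.5 + 1/3.0) = 3 * 0 = 0
//                      (the supervision indexes are left unshifted);
//   frame_shift == 2:  3 * floor(0.5 + 2/3.0) = 3 * 1 = 3.
// So the input 't' values move by frame_shift (unless excluded via
// 'exclude_names'), while the output 't' values move by the nearest multiple
// of frame_subsampling_factor.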

size_t NnetDiscriminativeExampleStructureHasher::operator () (
    const NnetDiscriminativeExample &eg) const noexcept {
  // these numbers were chosen at random from a list of primes.
  NnetIoStructureHasher io_hasher;
  size_t size = eg.inputs.size(), ans = size * 35099;
  for (size_t i = 0; i < size; i++)
    ans = ans * 19157 + io_hasher(eg.inputs[i]);
  for (size_t i = 0; i < eg.outputs.size(); i++) {
    const NnetDiscriminativeSupervision &sup = eg.outputs[i];
    StringHasher string_hasher;
    IndexVectorHasher indexes_hasher;
    ans = ans * 17957 +
        string_hasher(sup.name) + indexes_hasher(sup.indexes);
  }
  return ans;
}

bool NnetDiscriminativeExampleStructureCompare::operator () (
    const NnetDiscriminativeExample &a,
    const NnetDiscriminativeExample &b) const {
  NnetIoStructureCompare io_compare;
  if (a.inputs.size() != b.inputs.size() ||
      a.outputs.size() != b.outputs.size())
    return false;
  size_t size = a.inputs.size();
  for (size_t i = 0; i < size; i++)
    if (!io_compare(a.inputs[i], b.inputs[i]))
      return false;
  size = a.outputs.size();
  for (size_t i = 0; i < size; i++)
    if (a.outputs[i].name != b.outputs[i].name ||
        a.outputs[i].indexes != b.outputs[i].indexes)
      return false;
  return true;
}


int32 GetNnetDiscriminativeExampleSize(const NnetDiscriminativeExample &a) {
  int32 ans = 0;
  for (size_t i = 0; i < a.inputs.size(); i++) {
    int32 s = a.inputs[i].indexes.size();
    if (s > ans)
      ans = s;
  }
  for (size_t i = 0; i < a.outputs.size(); i++) {
    int32 s = a.outputs[i].indexes.size();
    if (s > ans)
      ans = s;
  }
  return ans;
}


DiscriminativeExampleMerger::DiscriminativeExampleMerger(
    const ExampleMergingConfig &config,
    NnetDiscriminativeExampleWriter *writer):
    finished_(false), num_egs_written_(0),
    config_(config), writer_(writer) { }


void DiscriminativeExampleMerger::AcceptExample(NnetDiscriminativeExample *eg) {
  KALDI_ASSERT(!finished_);
  // If an eg with the same structure as 'eg' is already a key in the
  // map, it won't be replaced, but if it's new it will be made
  // the key.  Also we remove the key before making the vector empty.
  // This way we ensure that the eg in the key is always the first
  // element of the vector.
  std::vector<NnetDiscriminativeExample*> &vec = eg_to_egs_[eg];
  vec.push_back(eg);
  int32 eg_size = GetNnetDiscriminativeExampleSize(*eg),
      num_available = vec.size();
  bool input_ended = false;
  int32 minibatch_size = config_.MinibatchSize(eg_size, num_available,
                                               input_ended);
  if (minibatch_size != 0) {  // we need to write out a merged eg.
    KALDI_ASSERT(minibatch_size == num_available);

    std::vector<NnetDiscriminativeExample*> vec_copy(vec);
    eg_to_egs_.erase(eg);

    // MergeDiscriminativeExamples() expects a vector of
    // NnetDiscriminativeExample, not of pointers, so use swap to create that
    // without doing any real work.
    std::vector<NnetDiscriminativeExample> egs_to_merge(minibatch_size);
    for (int32 i = 0; i < minibatch_size; i++) {
      egs_to_merge[i].Swap(vec_copy[i]);
      delete vec_copy[i];  // we owned those pointers.
    }
    WriteMinibatch(&egs_to_merge);
  }
}

void DiscriminativeExampleMerger::WriteMinibatch(
    std::vector<NnetDiscriminativeExample> *egs) {
  KALDI_ASSERT(!egs->empty());
  int32 eg_size = GetNnetDiscriminativeExampleSize((*egs)[0]);
  NnetDiscriminativeExampleStructureHasher eg_hasher;
  size_t structure_hash = eg_hasher((*egs)[0]);
  int32 minibatch_size = egs->size();
  stats_.WroteExample(eg_size, structure_hash, minibatch_size);
  NnetDiscriminativeExample merged_eg;
  MergeDiscriminativeExamples(config_.compress, egs, &merged_eg);
  std::ostringstream key;
  key << "merged-" << (num_egs_written_++) << "-" << minibatch_size;
  writer_->Write(key.str(), merged_eg);
}

void DiscriminativeExampleMerger::Finish() {
  if (finished_) return;  // already finished.
  finished_ = true;

  // we'll convert the map eg_to_egs_ to a vector of vectors to avoid
  // iterator invalidation problems.
  std::vector<std::vector<NnetDiscriminativeExample*> > all_egs;
  all_egs.reserve(eg_to_egs_.size());

  MapType::iterator iter = eg_to_egs_.begin(), end = eg_to_egs_.end();
  for (; iter != end; ++iter)
    all_egs.push_back(iter->second);
  eg_to_egs_.clear();

  for (size_t i = 0; i < all_egs.size(); i++) {
    int32 minibatch_size;
    std::vector<NnetDiscriminativeExample*> &vec = all_egs[i];
    KALDI_ASSERT(!vec.empty());
    int32 eg_size = GetNnetDiscriminativeExampleSize(*(vec[0]));
    bool input_ended = true;
    while (!vec.empty() &&
           (minibatch_size = config_.MinibatchSize(eg_size, vec.size(),
                                                   input_ended)) != 0) {
      // MergeDiscriminativeExamples() expects a vector of
      // NnetDiscriminativeExample, not of pointers, so use swap to create that
      // without doing any real work.
      std::vector<NnetDiscriminativeExample> egs_to_merge(minibatch_size);
      for (int32 i = 0; i < minibatch_size; i++) {
        egs_to_merge[i].Swap(vec[i]);
        delete vec[i];  // we owned those pointers.
      }
      vec.erase(vec.begin(), vec.begin() + minibatch_size);
      WriteMinibatch(&egs_to_merge);
    }
    if (!vec.empty()) {
      int32 eg_size = GetNnetDiscriminativeExampleSize(*(vec[0]));
      NnetDiscriminativeExampleStructureHasher eg_hasher;
      size_t structure_hash = eg_hasher(*(vec[0]));
      int32 num_discarded = vec.size();
      stats_.DiscardedExamples(eg_size, structure_hash, num_discarded);
      for (int32 i = 0; i < num_discarded; i++)
        delete vec[i];
      vec.clear();
    }
  }
  stats_.PrintStats();
}
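
// Typical usage of DiscriminativeExampleMerger, roughly as in the
// nnet3-discriminative-merge-egs binary (a sketch; the rspecifier/wspecifier
// strings and 'config' are assumed to come from the command line):
//
//   SequentialNnetDiscriminativeExampleReader example_reader(examples_rspecifier);
//   NnetDiscriminativeExampleWriter example_writer(examples_wspecifier);
//   DiscriminativeExampleMerger merger(config, &example_writer);
//   for (; !example_reader.Done(); example_reader.Next())
//     // the merger takes ownership of the pointer passed in.
//     merger.AcceptExample(new NnetDiscriminativeExample(example_reader.Value()));
//   merger.Finish();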


} // namespace nnet3
} // namespace kaldi