DiscriminativeExampleMerger Class Reference

This class is responsible for arranging examples in groups that have the same structure and outputting them in suitable minibatches.

#include <nnet-discriminative-example.h>


Public Member Functions

 DiscriminativeExampleMerger (const ExampleMergingConfig &config, NnetDiscriminativeExampleWriter *writer)
 
void AcceptExample (NnetDiscriminativeExample *a)
 
void Finish ()
 
int32 ExitStatus ()
 
 ~DiscriminativeExampleMerger ()
 

Private Types

typedef unordered_map< NnetDiscriminativeExample *, std::vector< NnetDiscriminativeExample * >, NnetDiscriminativeExampleStructureHasher, NnetDiscriminativeExampleStructureCompare > MapType
 

Private Member Functions

void WriteMinibatch (std::vector< NnetDiscriminativeExample > *egs)
 

Private Attributes

bool finished_
 
int32 num_egs_written_
 
const ExampleMergingConfig & config_
 
NnetDiscriminativeExampleWriter * writer_
 
ExampleMergingStats stats_
 
MapType eg_to_egs_
 

Detailed Description

This class is responsible for arranging examples in groups that have the same structure (i.e. the same input and output indexes), and outputting them in suitable minibatches as defined by ExampleMergingConfig.

Definition at line 228 of file nnet-discriminative-example.h.
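The grouping rests on MapType: the map is keyed by example pointers, but it hashes and compares the examples' structure rather than the pointer values. The following is a minimal sketch of that pattern. ToyExample, ToyStructureHasher, ToyStructureCompare and ToyMapType are hypothetical stand-ins for NnetDiscriminativeExample, NnetDiscriminativeExampleStructureHasher, NnetDiscriminativeExampleStructureCompare and MapType, and the hash function is an assumption rather than Kaldi's.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for NnetDiscriminativeExample: only the "structure"
// (here a list of indexes) matters for grouping; 'payload' is ignored.
struct ToyExample {
  std::vector<int> indexes;  // structure: what the hasher/comparator look at
  double payload;            // data that plays no part in grouping
};

// Hashes the pointed-to structure, not the pointer value (the role played by
// NnetDiscriminativeExampleStructureHasher; the hash itself is made up).
struct ToyStructureHasher {
  std::size_t operator()(const ToyExample *eg) const {
    std::size_t h = eg->indexes.size();
    for (int idx : eg->indexes)
      h = h * 31 + std::hash<int>()(idx);
    return h;
  }
};

// Two pointers are "equal" keys if their examples share the same structure
// (the role played by NnetDiscriminativeExampleStructureCompare).
struct ToyStructureCompare {
  bool operator()(const ToyExample *a, const ToyExample *b) const {
    return a->indexes == b->indexes;
  }
};

// Same shape as MapType: first eg of a structure -> all egs of that structure.
typedef std::unordered_map<ToyExample *, std::vector<ToyExample *>,
                           ToyStructureHasher, ToyStructureCompare>
    ToyMapType;
```

When a second example with an already-seen structure is looked up, operator[] finds the existing entry and keeps the original key. This is why the first example of each group remains the key.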

Member Typedef Documentation

◆ MapType

Constructor & Destructor Documentation

◆ DiscriminativeExampleMerger()

◆ ~DiscriminativeExampleMerger()

Member Function Documentation

◆ AcceptExample()

void AcceptExample ( NnetDiscriminativeExample * a)

Definition at line 452 of file nnet-discriminative-example.cc.

References DiscriminativeExampleMerger::config_, DiscriminativeExampleMerger::eg_to_egs_, DiscriminativeExampleMerger::finished_, kaldi::nnet3::GetNnetDiscriminativeExampleSize(), KALDI_ASSERT, ExampleMergingConfig::MinibatchSize(), and DiscriminativeExampleMerger::WriteMinibatch().

Referenced by main().

{
  KALDI_ASSERT(!finished_);
  // If an eg with the same structure as 'eg' is already a key in the
  // map, it won't be replaced, but if it's new it will be made
  // the key. Also we remove the key before making the vector empty.
  // This way we ensure that the eg in the key is always the first
  // element of the vector.
  std::vector<NnetDiscriminativeExample*> &vec = eg_to_egs_[eg];
  vec.push_back(eg);
  int32 eg_size = GetNnetDiscriminativeExampleSize(*eg),
      num_available = vec.size();
  bool input_ended = false;
  int32 minibatch_size = config_.MinibatchSize(eg_size, num_available,
                                               input_ended);
  if (minibatch_size != 0) {  // we need to write out a merged eg.
    KALDI_ASSERT(minibatch_size == num_available);

    std::vector<NnetDiscriminativeExample*> vec_copy(vec);
    eg_to_egs_.erase(eg);

    // MergeDiscriminativeExamples() expects a vector of NnetDiscriminativeExample, not of pointers,
    // so use swap to create that without doing any real work.
    std::vector<NnetDiscriminativeExample> egs_to_merge(minibatch_size);
    for (int32 i = 0; i < minibatch_size; i++) {
      egs_to_merge[i].Swap(vec_copy[i]);
      delete vec_copy[i];  // we owned those pointers.
    }
    WriteMinibatch(&egs_to_merge);
  }
}
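The accept-and-flush logic can be tried out on a toy version. It keeps three features of AcceptExample(): examples are grouped in a structure-keyed map, the map entry is erased while every example is still alive (the hasher dereferences the key), and owned heap objects are turned into a vector of values with Swap() before the pointers are deleted. This is a sketch under stated assumptions. ToyEg, ToyMerger and their members are hypothetical, and a fixed minibatch size replaces the real ExampleMergingConfig::MinibatchSize() decision.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for NnetDiscriminativeExample.
struct ToyEg {
  ToyEg() : structure(0) {}
  ToyEg(int s, std::vector<float> d) : structure(s), data(std::move(d)) {}
  void Swap(ToyEg *other) {  // cheap exchange of contents, like Swap() in Kaldi
    std::swap(structure, other->structure);
    data.swap(other->data);
  }
  int structure;            // what grouping looks at
  std::vector<float> data;  // payload
};

// Hash and compare the pointed-to structure, not the pointer value.
struct ToyHasher {
  std::size_t operator()(const ToyEg *eg) const {
    return std::hash<int>()(eg->structure);
  }
};
struct ToyCompare {
  bool operator()(const ToyEg *a, const ToyEg *b) const {
    return a->structure == b->structure;
  }
};

// Hypothetical merger; a fixed minibatch size replaces MinibatchSize().
class ToyMerger {
 public:
  explicit ToyMerger(std::size_t minibatch_size)
      : minibatch_size_(minibatch_size) {}
  ToyMerger(const ToyMerger &) = delete;
  ToyMerger &operator=(const ToyMerger &) = delete;

  // Takes ownership of 'eg'.
  void AcceptExample(ToyEg *eg) {
    std::vector<ToyEg *> &vec = eg_to_egs_[eg];
    vec.push_back(eg);
    if (vec.size() < minibatch_size_) return;  // group not full yet
    // Copy the pointers, then erase the entry (this destroys 'vec') while all
    // examples are still alive, because erase() hashes the structure.
    std::vector<ToyEg *> vec_copy(vec);
    eg_to_egs_.erase(eg);
    // Build a vector of values by swapping contents out of the heap objects.
    std::vector<ToyEg> egs_to_merge(vec_copy.size());
    for (std::size_t i = 0; i < vec_copy.size(); i++) {
      egs_to_merge[i].Swap(vec_copy[i]);
      delete vec_copy[i];  // we owned this pointer
    }
    written_sizes_.push_back(egs_to_merge.size());  // stands in for WriteMinibatch()
  }

  ~ToyMerger() {
    // Detach all pending pointers from the map before deleting them.
    std::vector<ToyEg *> pending;
    for (auto &kv : eg_to_egs_)
      pending.insert(pending.end(), kv.second.begin(), kv.second.end());
    eg_to_egs_.clear();
    for (ToyEg *p : pending) delete p;
  }

  const std::vector<std::size_t> &written_sizes() const { return written_sizes_; }
  std::size_t num_pending_groups() const { return eg_to_egs_.size(); }

 private:
  std::size_t minibatch_size_;
  std::unordered_map<ToyEg *, std::vector<ToyEg *>, ToyHasher, ToyCompare>
      eg_to_egs_;
  std::vector<std::size_t> written_sizes_;
};
```

With a minibatch size of 2, feeding structures 1, 2, 1 writes one minibatch of size 2 for structure 1. Structure 2 stays pending.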

◆ ExitStatus()

int32 ExitStatus ( )
inline

◆ Finish()

void Finish ( )

Definition at line 498 of file nnet-discriminative-example.cc.

References DiscriminativeExampleMerger::config_, ExampleMergingStats::DiscardedExamples(), DiscriminativeExampleMerger::eg_to_egs_, DiscriminativeExampleMerger::finished_, kaldi::nnet3::GetNnetDiscriminativeExampleSize(), KALDI_ASSERT, ExampleMergingConfig::MinibatchSize(), ExampleMergingStats::PrintStats(), DiscriminativeExampleMerger::stats_, and DiscriminativeExampleMerger::WriteMinibatch().

Referenced by main().

{
  if (finished_) return;  // already finished.
  finished_ = true;

  // we'll convert the map eg_to_egs_ to a vector of vectors to avoid
  // iterator invalidation problems.
  std::vector<std::vector<NnetDiscriminativeExample*> > all_egs;
  all_egs.reserve(eg_to_egs_.size());

  MapType::iterator iter = eg_to_egs_.begin(), end = eg_to_egs_.end();
  for (; iter != end; ++iter)
    all_egs.push_back(iter->second);
  eg_to_egs_.clear();

  for (size_t i = 0; i < all_egs.size(); i++) {
    int32 minibatch_size;
    std::vector<NnetDiscriminativeExample*> &vec = all_egs[i];
    KALDI_ASSERT(!vec.empty());
    int32 eg_size = GetNnetDiscriminativeExampleSize(*(vec[0]));
    bool input_ended = true;
    while (!vec.empty() &&
           (minibatch_size = config_.MinibatchSize(eg_size, vec.size(),
                                                   input_ended)) != 0) {
      // MergeDiscriminativeExamples() expects a vector of
      // NnetDiscriminativeExample, not of pointers, so use swap to create that
      // without doing any real work.
      std::vector<NnetDiscriminativeExample> egs_to_merge(minibatch_size);
      for (int32 i = 0; i < minibatch_size; i++) {
        egs_to_merge[i].Swap(vec[i]);
        delete vec[i];  // we owned those pointers.
      }
      vec.erase(vec.begin(), vec.begin() + minibatch_size);
      WriteMinibatch(&egs_to_merge);
    }
    if (!vec.empty()) {
      int32 eg_size = GetNnetDiscriminativeExampleSize(*(vec[0]));
      NnetDiscriminativeExampleStructureHasher eg_hasher;
      size_t structure_hash = eg_hasher(*(vec[0]));
      int32 num_discarded = vec.size();
      stats_.DiscardedExamples(eg_size, structure_hash, num_discarded);
      for (int32 i = 0; i < num_discarded; i++)
        delete vec[i];
      vec.clear();
    }
  }
  stats_.PrintStats();
}
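Finish() snapshots the map into a vector of vectors and clears the map. It then drains each group with input_ended set to true, writing minibatches for as long as the size rule returns a nonzero size, and counts whatever remains as discarded. The idempotence guard (returning early if finished_ is set) is omitted here. The following is a simplified sketch of that loop. It uses plain ints as examples, so there is no ownership. ToyMinibatchSize, FinishResult, ToyFinish and the min_final parameter are hypothetical, and the "full minibatch, or a final one of at least min_final" rule is an assumption that stands in for the real ExampleMergingConfig::MinibatchSize().

```cpp
#include <cassert>
#include <unordered_map>
#include <vector>

// Hypothetical size rule: returns the size of the next minibatch to write,
// or 0 for "don't write". Full minibatches are always allowed; once input
// has ended, a smaller final minibatch is allowed if it has >= min_final egs.
int ToyMinibatchSize(int full_size, int min_final, int num_available,
                     bool input_ended) {
  if (num_available >= full_size) return full_size;
  if (input_ended && num_available >= min_final) return num_available;
  return 0;
}

struct FinishResult {
  std::vector<int> minibatch_sizes;  // sizes of the minibatches "written"
  int num_discarded = 0;             // egs dropped because no rule fit
};

// Mirrors the loop structure of Finish(); 'groups' maps a structure id to
// the egs (here just ints) still waiting in that group.
FinishResult ToyFinish(std::unordered_map<int, std::vector<int>> *groups,
                       int full_size, int min_final) {
  FinishResult result;
  // Snapshot the map, then clear it, so we never modify what we iterate over.
  std::vector<std::vector<int>> all_egs;
  all_egs.reserve(groups->size());
  for (auto &kv : *groups) all_egs.push_back(kv.second);
  groups->clear();

  for (std::vector<int> &vec : all_egs) {
    int minibatch_size;
    while (!vec.empty() &&
           (minibatch_size = ToyMinibatchSize(full_size, min_final,
                                              static_cast<int>(vec.size()),
                                              true)) != 0) {
      result.minibatch_sizes.push_back(minibatch_size);  // "write" it
      vec.erase(vec.begin(), vec.begin() + minibatch_size);
    }
    // Whatever is left cannot form an acceptable minibatch: discard it.
    result.num_discarded += static_cast<int>(vec.size());
    vec.clear();
  }
  return result;
}
```

Under this toy rule, with full_size 2 and min_final 2, a group of 5 yields two minibatches of 2 and discards 1. A lone example in another group is also discarded. This mirrors how the real class reports leftovers through stats_.DiscardedExamples().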

◆ WriteMinibatch()

void WriteMinibatch ( std::vector< NnetDiscriminativeExample > *  egs)
private

Definition at line 483 of file nnet-discriminative-example.cc.

References ExampleMergingConfig::compress, DiscriminativeExampleMerger::config_, kaldi::nnet3::GetNnetDiscriminativeExampleSize(), KALDI_ASSERT, kaldi::nnet3::MergeDiscriminativeExamples(), DiscriminativeExampleMerger::num_egs_written_, DiscriminativeExampleMerger::stats_, TableWriter< Holder >::Write(), DiscriminativeExampleMerger::writer_, and ExampleMergingStats::WroteExample().

Referenced by DiscriminativeExampleMerger::AcceptExample(), and DiscriminativeExampleMerger::Finish().

{
  KALDI_ASSERT(!egs->empty());
  int32 eg_size = GetNnetDiscriminativeExampleSize((*egs)[0]);
  NnetDiscriminativeExampleStructureHasher eg_hasher;
  size_t structure_hash = eg_hasher((*egs)[0]);
  int32 minibatch_size = egs->size();
  stats_.WroteExample(eg_size, structure_hash, minibatch_size);
  NnetDiscriminativeExample merged_eg;
  MergeDiscriminativeExamples(config_.compress, egs, &merged_eg);
  std::ostringstream key;
  key << "merged-" << (num_egs_written_++) << "-" << minibatch_size;
  writer_->Write(key.str(), merged_eg);
}
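Each merged example is written under a key of the form "merged-<count>-<minibatch_size>". The running counter num_egs_written_ keeps keys unique even when consecutive minibatches have the same size. The sketch below reproduces only that key format. MakeMergedKey is a hypothetical helper, and its num_written argument plays the role of num_egs_written_.

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Builds a key in the "merged-<index>-<minibatch_size>" format used by
// WriteMinibatch(); '*num_written' is incremented on every call.
std::string MakeMergedKey(int *num_written, int minibatch_size) {
  std::ostringstream key;
  key << "merged-" << ((*num_written)++) << "-" << minibatch_size;
  return key.str();
}
```

For example, two successive calls starting from a counter of 0 with minibatch sizes 64 and 32 give "merged-0-64" and "merged-1-32".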

Member Data Documentation

◆ config_

◆ eg_to_egs_

◆ finished_

◆ num_egs_written_

int32 num_egs_written_
private

◆ stats_

◆ writer_


The documentation for this class was generated from the following files: nnet-discriminative-example.h and nnet-discriminative-example.cc.