ChainExampleMerger Class Reference

This class is responsible for arranging examples in groups that have the same structure (i.e. the same input and output indexes).

#include <nnet-chain-example.h>

Public Member Functions

 ChainExampleMerger (const ExampleMergingConfig &config, NnetChainExampleWriter *writer)
 
void AcceptExample (NnetChainExample *a)
 
void Finish ()
 
int32 ExitStatus ()
 
 ~ChainExampleMerger ()
 

Private Types

typedef unordered_map< NnetChainExample *, std::vector< NnetChainExample * >, NnetChainExampleStructureHasher, NnetChainExampleStructureCompare > MapType
 

Private Member Functions

void WriteMinibatch (std::vector< NnetChainExample > *egs)
 

Private Attributes

bool finished_
 
int32 num_egs_written_
 
const ExampleMergingConfig & config_
 
NnetChainExampleWriter * writer_
 
ExampleMergingStats stats_
 
MapType eg_to_egs_
 

Detailed Description

This class is responsible for arranging examples in groups that have the same structure (i.e. the same input and output indexes), and outputting them in suitable minibatches as defined by ExampleMergingConfig.

Definition at line 230 of file nnet-chain-example.h.
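
The grouping idea described above can be illustrated with a minimal, self-contained sketch. It does not use the Kaldi types: `ToyExample`, `ToyMerger`, and the fixed `batch_size` are hypothetical stand-ins, whereas the real class takes its minibatch size from ExampleMergingConfig and compares structures with NnetChainExampleStructureHasher and NnetChainExampleStructureCompare.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for NnetChainExample: 'structure' plays the role of
// the input/output index layout, 'id' is just a payload.
struct ToyExample {
  std::string structure;
  int id;
};

// Groups examples by structure and emits a minibatch whenever a group reaches
// 'batch_size'. The fixed batch size is an assumption of this sketch.
class ToyMerger {
 public:
  explicit ToyMerger(std::size_t batch_size) : batch_size_(batch_size) {
    assert(batch_size > 0);
  }

  void Accept(const ToyExample &eg) {
    std::vector<ToyExample> &group = groups_[eg.structure];
    group.push_back(eg);
    if (group.size() == batch_size_) {
      batches_.push_back(group);    // "write" the minibatch (copies the group)
      groups_.erase(eg.structure);  // the next example starts a fresh group
    }
  }

  // Minibatches emitted so far, in the order they were completed.
  const std::vector<std::vector<ToyExample> > &Batches() const {
    return batches_;
  }

  // Number of examples still waiting for their group to fill up.
  std::size_t NumPending() const {
    std::size_t n = 0;
    for (std::map<std::string, std::vector<ToyExample> >::const_iterator
             it = groups_.begin(); it != groups_.end(); ++it)
      n += it->second.size();
    return n;
  }

 private:
  std::size_t batch_size_;
  std::map<std::string, std::vector<ToyExample> > groups_;
  std::vector<std::vector<ToyExample> > batches_;
};
```

With a batch size of 2, accepting examples with structures "A", "B", "A" would emit one minibatch holding the two "A" examples and leave the "B" example pending.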

Member Typedef Documentation

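typedef unordered_map< NnetChainExample *, std::vector< NnetChainExample * >, NnetChainExampleStructureHasher, NnetChainExampleStructureCompare > MapType
private

Definition at line 268 of file nnet-chain-example.h.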

Constructor & Destructor Documentation

Definition at line 456 of file nnet-chain-example.cc.

ChainExampleMerger::ChainExampleMerger(const ExampleMergingConfig &config,
                                       NnetChainExampleWriter *writer):
    finished_(false), num_egs_written_(0),
    config_(config), writer_(writer) { }

~ChainExampleMerger ( )
inline

Definition at line 250 of file nnet-chain-example.h.

References ChainExampleMerger::Finish().
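
Because the destructor references Finish(), and Finish() returns early once `finished_` is set, pending work is flushed exactly once whether or not the caller finishes explicitly. A minimal sketch of that pattern follows; `ToyFlusher` and its flush counter are hypothetical and only stand in for the real writing logic.

```cpp
#include <cassert>

// Hypothetical class showing an idempotent Finish() that the destructor also
// calls, so pending work is flushed even if the caller never calls Finish().
class ToyFlusher {
 public:
  explicit ToyFlusher(int *flush_count)
      : finished_(false), flush_count_(flush_count) {
    assert(flush_count != nullptr);
  }

  ~ToyFlusher() { Finish(); }

  void Finish() {
    if (finished_) return;  // already finished; a second call does nothing
    finished_ = true;
    ++(*flush_count_);      // stands in for writing the remaining minibatches
  }

 private:
  bool finished_;
  int *flush_count_;  // not owned
};
```

Calling Finish() and then letting the object go out of scope increments the counter once, not twice.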

Member Function Documentation

void AcceptExample ( NnetChainExample * a)

Definition at line 462 of file nnet-chain-example.cc.

References ChainExampleMerger::config_, ChainExampleMerger::eg_to_egs_, ChainExampleMerger::finished_, kaldi::nnet3::GetNnetChainExampleSize(), rnnlm::i, KALDI_ASSERT, ExampleMergingConfig::MinibatchSize(), and ChainExampleMerger::WriteMinibatch().

void ChainExampleMerger::AcceptExample(NnetChainExample *eg) {
  KALDI_ASSERT(!finished_);
  // If an eg with the same structure as 'eg' is already a key in the
  // map, it won't be replaced, but if it's new it will be made
  // the key.  Also we remove the key before making the vector empty.
  // This way we ensure that the eg in the key is always the first
  // element of the vector.
  std::vector<NnetChainExample*> &vec = eg_to_egs_[eg];
  vec.push_back(eg);
  int32 eg_size = GetNnetChainExampleSize(*eg),
      num_available = vec.size();
  bool input_ended = false;
  int32 minibatch_size = config_.MinibatchSize(eg_size, num_available,
                                               input_ended);
  if (minibatch_size != 0) {  // we need to write out a merged eg.
    KALDI_ASSERT(minibatch_size == num_available);

    std::vector<NnetChainExample*> vec_copy(vec);
    eg_to_egs_.erase(eg);

    // MergeChainExamples() expects a vector of NnetChainExample, not of pointers,
    // so use swap to create that without doing any real work.
    std::vector<NnetChainExample> egs_to_merge(minibatch_size);
    for (int32 i = 0; i < minibatch_size; i++) {
      egs_to_merge[i].Swap(vec_copy[i]);
      delete vec_copy[i];  // we owned those pointers.
    }
    WriteMinibatch(&egs_to_merge);
  }
}
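
The swap step in the listing above turns a vector of owned pointers into a vector of values without deep-copying the payload. Below is a minimal sketch of that idiom with a hypothetical `ToyEg` type and a hypothetical helper `TakeOwnership`; the real code uses NnetChainExample::Swap() in the same way.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical payload type with a cheap Swap(), standing in for
// NnetChainExample.
struct ToyEg {
  std::string data;
  void Swap(ToyEg *other) { std::swap(data, other->data); }
};

// Moves the contents of heap-allocated objects into a vector of values and
// frees the originals. The caller must own the pointers; 'ptrs' is emptied.
std::vector<ToyEg> TakeOwnership(std::vector<ToyEg*> *ptrs) {
  assert(ptrs != nullptr);
  std::vector<ToyEg> values(ptrs->size());
  for (std::size_t i = 0; i < ptrs->size(); i++) {
    values[i].Swap((*ptrs)[i]);  // exchange contents instead of copying them
    delete (*ptrs)[i];           // release the now-empty object
  }
  ptrs->clear();
  return values;
}
```

After the call, the returned vector holds the original contents in the same order and no pointer remains to be freed.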
int32 ExitStatus ( )
inline

void Finish ( )

Definition at line 508 of file nnet-chain-example.cc.

References ChainExampleMerger::config_, ExampleMergingStats::DiscardedExamples(), ChainExampleMerger::eg_to_egs_, ChainExampleMerger::finished_, kaldi::nnet3::GetNnetChainExampleSize(), rnnlm::i, KALDI_ASSERT, ExampleMergingConfig::MinibatchSize(), ExampleMergingStats::PrintStats(), ChainExampleMerger::stats_, and ChainExampleMerger::WriteMinibatch().

Referenced by ChainExampleMerger::ExitStatus(), and ChainExampleMerger::~ChainExampleMerger().

void ChainExampleMerger::Finish() {
  if (finished_) return;  // already finished.
  finished_ = true;

  // we'll convert the map eg_to_egs_ to a vector of vectors to avoid
  // iterator invalidation problems.
  std::vector<std::vector<NnetChainExample*> > all_egs;
  all_egs.reserve(eg_to_egs_.size());

  MapType::iterator iter = eg_to_egs_.begin(), end = eg_to_egs_.end();
  for (; iter != end; ++iter)
    all_egs.push_back(iter->second);
  eg_to_egs_.clear();

  for (size_t i = 0; i < all_egs.size(); i++) {
    int32 minibatch_size;
    std::vector<NnetChainExample*> &vec = all_egs[i];
    KALDI_ASSERT(!vec.empty());
    int32 eg_size = GetNnetChainExampleSize(*(vec[0]));
    bool input_ended = true;
    while (!vec.empty() &&
           (minibatch_size = config_.MinibatchSize(eg_size, vec.size(),
                                                   input_ended)) != 0) {
      // MergeChainExamples() expects a vector of
      // NnetChainExample, not of pointers, so use swap to create that
      // without doing any real work.
      std::vector<NnetChainExample> egs_to_merge(minibatch_size);
      for (int32 i = 0; i < minibatch_size; i++) {
        egs_to_merge[i].Swap(vec[i]);
        delete vec[i];  // we owned those pointers.
      }
      vec.erase(vec.begin(), vec.begin() + minibatch_size);
      WriteMinibatch(&egs_to_merge);
    }
    if (!vec.empty()) {
      int32 eg_size = GetNnetChainExampleSize(*(vec[0]));
      NnetChainExampleStructureHasher eg_hasher;
      size_t structure_hash = eg_hasher(*(vec[0]));
      int32 num_discarded = vec.size();
      stats_.DiscardedExamples(eg_size, structure_hash, num_discarded);
      for (int32 i = 0; i < num_discarded; i++)
        delete vec[i];
      vec.clear();
    }
  }
  stats_.PrintStats();
}
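
The end-of-input loop above keeps asking for a minibatch size until the answer is 0, then discards whatever is left in the group. The sketch below reproduces only that control flow on plain counts. `ToyMinibatchSize` is a hypothetical rule invented for illustration (a fixed target, with an optional smaller final batch); the actual behaviour of ExampleMergingConfig::MinibatchSize() is not shown on this page and may differ.

```cpp
#include <cassert>
#include <vector>

// Hypothetical minibatch-size rule (an assumption of this sketch, not the
// Kaldi rule). Returns 0 when nothing should be written yet.
int ToyMinibatchSize(int target, bool allow_partial, int num_available,
                     bool input_ended) {
  assert(target > 0 && num_available >= 0);
  if (num_available >= target) return target;
  if (input_ended && allow_partial) return num_available;
  return 0;
}

// Drains one group at end of input. Returns the sizes of the minibatches that
// would be written; *num_discarded receives the number of leftover examples.
std::vector<int> DrainGroup(int target, bool allow_partial, int num_available,
                            int *num_discarded) {
  assert(num_discarded != nullptr);
  std::vector<int> written;
  int size;
  while (num_available > 0 &&
         (size = ToyMinibatchSize(target, allow_partial, num_available,
                                  true)) != 0) {
    written.push_back(size);  // stands in for WriteMinibatch()
    num_available -= size;
  }
  *num_discarded = num_available;  // stands in for stats_.DiscardedExamples()
  return written;
}
```

Under this toy rule, 7 examples with a target of 3 would be written as minibatches of 3, 3 and 1 when partial batches are allowed, and as 3 and 3 with one example discarded when they are not.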
void WriteMinibatch ( std::vector< NnetChainExample > *  egs)
private

Definition at line 493 of file nnet-chain-example.cc.

References ExampleMergingConfig::compress, ChainExampleMerger::config_, kaldi::nnet3::GetNnetChainExampleSize(), KALDI_ASSERT, kaldi::nnet3::MergeChainExamples(), ChainExampleMerger::num_egs_written_, ChainExampleMerger::stats_, TableWriter< Holder >::Write(), ChainExampleMerger::writer_, and ExampleMergingStats::WroteExample().

Referenced by ChainExampleMerger::AcceptExample(), and ChainExampleMerger::Finish().

void ChainExampleMerger::WriteMinibatch(
    std::vector<NnetChainExample> *egs) {
  KALDI_ASSERT(!egs->empty());
  int32 eg_size = GetNnetChainExampleSize((*egs)[0]);
  NnetChainExampleStructureHasher eg_hasher;
  size_t structure_hash = eg_hasher((*egs)[0]);
  int32 minibatch_size = egs->size();
  stats_.WroteExample(eg_size, structure_hash, minibatch_size);
  NnetChainExample merged_eg;
  MergeChainExamples(config_.compress, egs, &merged_eg);
  std::ostringstream key;
  key << "merged-" << (num_egs_written_++) << "-" << minibatch_size;
  writer_->Write(key.str(), merged_eg);
}
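
The key written for each merged example has the form "merged-<index>-<minibatch size>", where the index is the running count of merged examples written so far. The key construction can be tried in isolation with the sketch below; `MakeMergedKey` is a hypothetical helper name that does not exist in Kaldi.

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Builds "merged-<index>-<minibatch_size>" and advances the running index,
// following the key format used in WriteMinibatch(). Hypothetical helper.
std::string MakeMergedKey(int *num_written, int minibatch_size) {
  assert(num_written != nullptr);
  std::ostringstream key;
  key << "merged-" << (*num_written)++ << "-" << minibatch_size;
  return key.str();
}
```

Starting from a counter of 0, a first call with minibatch size 64 yields "merged-0-64" and leaves the counter at 1.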

Member Data Documentation

MapType eg_to_egs_
private

bool finished_
private

int32 num_egs_written_
private

NnetChainExampleWriter* writer_
private

Definition at line 261 of file nnet-chain-example.h.

Referenced by ChainExampleMerger::WriteMinibatch().


The documentation for this class was generated from the following files: