All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
NnetEnsembleTrainer Class Reference

#include <train-nnet-ensemble.h>

Collaboration diagram for NnetEnsembleTrainer:

Public Member Functions

 NnetEnsembleTrainer (const NnetEnsembleTrainerConfig &config, std::vector< Nnet * > nnet_ensemble)
 
void TrainOnExample (const NnetExample &value)
 TrainOnExample will take the example and add it to a buffer; if we've reached the minibatch size it will do the training. More...
 
 ~NnetEnsembleTrainer ()
 

Private Member Functions

 KALDI_DISALLOW_COPY_AND_ASSIGN (NnetEnsembleTrainer)
 
void TrainOneMinibatch ()
 
void BeginNewPhase (bool first_time)
 

Private Attributes

NnetEnsembleTrainerConfig config_
 
std::vector< Nnet * > nnet_ensemble_
 
std::vector< NnetUpdater * > updater_ensemble_
 
int32 num_phases_
 
int32 minibatches_seen_this_phase_
 
std::vector< NnetExample > buffer_
 
double beta_
 
double avg_logprob_this_phase_
 
double count_this_phase_
 

Detailed Description

Definition at line 64 of file train-nnet-ensemble.h.

Constructor & Destructor Documentation

Definition at line 127 of file train-nnet-ensemble.cc.

References NnetEnsembleTrainer::BeginNewPhase(), NnetEnsembleTrainer::buffer_, KALDI_LOG, NnetEnsembleTrainer::minibatches_seen_this_phase_, and NnetEnsembleTrainer::TrainOneMinibatch().

127  {
128  if (!buffer_.empty()) {
129  KALDI_LOG << "Doing partial minibatch of size "
130  << buffer_.size();
131  TrainOneMinibatch();
132  if (minibatches_seen_this_phase_ != 0) {
133  bool first_time = false;
134  BeginNewPhase(first_time);
135  }
136  }
137 }
std::vector< NnetExample > buffer_
#define KALDI_LOG
Definition: kaldi-error.h:133

Member Function Documentation

void BeginNewPhase ( bool  first_time)
private
KALDI_DISALLOW_COPY_AND_ASSIGN ( NnetEnsembleTrainer  )
private
void TrainOneMinibatch ( )
private

Definition at line 50 of file train-nnet-ensemble.cc.

References NnetEnsembleTrainer::avg_logprob_this_phase_, NnetEnsembleTrainer::BeginNewPhase(), NnetEnsembleTrainer::beta_, NnetEnsembleTrainer::buffer_, NnetEnsembleTrainer::config_, NnetEnsembleTrainer::count_this_phase_, rnnlm::i, CuMatrixBase< Real >::InvertElements(), KALDI_ASSERT, kaldi::nnet2::MakePair(), NnetEnsembleTrainerConfig::minibatches_per_phase, NnetEnsembleTrainer::minibatches_seen_this_phase_, CuMatrixBase< Real >::MulElements(), NnetEnsembleTrainer::nnet_ensemble_, and NnetEnsembleTrainer::updater_ensemble_.

Referenced by NnetEnsembleTrainer::TrainOnExample(), and NnetEnsembleTrainer::~NnetEnsembleTrainer().

50  {
51  KALDI_ASSERT(!buffer_.empty());
52 
53  int32 num_states = nnet_ensemble_[0]->GetComponent(nnet_ensemble_[0]->NumComponents() - 1).OutputDim();
54  // average of posteriors matrix, storing averaged outputs of net ensemble.
55  CuMatrix<BaseFloat> post_avg(buffer_.size(), num_states);
56  updater_ensemble_.reserve(nnet_ensemble_.size());
57  std::vector<CuMatrix<BaseFloat> > post_mat;
58  post_mat.resize(nnet_ensemble_.size());
59  for (int32 i = 0; i < nnet_ensemble_.size(); i++) {
60  updater_ensemble_.push_back(new NnetUpdater(*(nnet_ensemble_[i]), nnet_ensemble_[i]));
61  updater_ensemble_[i]->FormatInput(buffer_);
62  updater_ensemble_[i]->Propagate();
63  // posterior matrix, storing output of one net.
64  updater_ensemble_[i]->GetOutput(&post_mat[i]);
65  CuVector<BaseFloat> row_sum(post_mat[i].NumRows());
66  post_avg.AddMat(1.0, post_mat[i]);
67  }
68 
69  // calculate the interpolated posteriors as new supervision labels, and also
70  // collect the indices of the original supervision labels for later use (calc. objf.).
71  std::vector<MatrixElement<BaseFloat> > sv_labels;
72  std::vector<Int32Pair > sv_labels_ind;
73  sv_labels.reserve(buffer_.size()); // We must have at least this many labels.
74  sv_labels_ind.reserve(buffer_.size()); // We must have at least this many labels.
75  for (int32 m = 0; m < buffer_.size(); m++) {
76  KALDI_ASSERT(buffer_[m].labels.size() == 1 &&
77  "Currently this code only supports single-frame egs.");
78  const std::vector<std::pair<int32,BaseFloat> > &labels = buffer_[m].labels[0];
79  for (size_t i = 0; i < labels.size(); i++) {
80  MatrixElement<BaseFloat>
81  tmp = {m, labels[i].first, labels[i].second};
82  sv_labels.push_back(tmp);
83  sv_labels_ind.push_back(MakePair(m, labels[i].first));
84  }
85  }
86  post_avg.Scale(1.0 / nnet_ensemble_.size());
87  post_avg.Scale(beta_);
88  post_avg.AddElements(1.0, sv_labels);
89 
90  // calculate the deriv, do backprop, and calculate the objf.
91  for (int32 i = 0; i < nnet_ensemble_.size(); i++) {
92  CuMatrix<BaseFloat> tmp_deriv(post_mat[i]);
93  post_mat[i].ApplyLog();
94  std::vector<BaseFloat> log_post_correct;
95  log_post_correct.resize(sv_labels_ind.size());
96  post_mat[i].Lookup(sv_labels_ind, &(log_post_correct[0]));
97  BaseFloat log_prob_this_net = std::accumulate(log_post_correct.begin(),
98  log_post_correct.end(),
99  static_cast<BaseFloat>(0));
100  avg_logprob_this_phase_ += log_prob_this_net;
101  tmp_deriv.InvertElements();
102  tmp_deriv.MulElements(post_avg);
103  updater_ensemble_[i]->Backprop(&tmp_deriv);
104  }
105  count_this_phase_ += buffer_.size();
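106  buffer_.clear();
107  minibatches_seen_this_phase_++;
108  if (minibatches_seen_this_phase_ == config_.minibatches_per_phase) {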
109  avg_logprob_this_phase_ /= static_cast<BaseFloat>(nnet_ensemble_.size());
110  bool first_time = false;
111  BeginNewPhase(first_time);
112  }
113 }
static Int32Pair MakePair(int32 first, int32 second)
std::vector< NnetExample > buffer_
float BaseFloat
Definition: kaldi-types.h:29
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:169
std::vector< NnetUpdater * > updater_ensemble_
NnetEnsembleTrainerConfig config_
void TrainOnExample ( const NnetExample & value)

TrainOnExample will take the example and add it to a buffer; if we've reached the minibatch size it will do the training.

Definition at line 44 of file train-nnet-ensemble.cc.

References NnetEnsembleTrainer::buffer_, NnetEnsembleTrainer::config_, NnetEnsembleTrainerConfig::minibatch_size, and NnetEnsembleTrainer::TrainOneMinibatch().

Referenced by main().

44  {
45  buffer_.push_back(value);
46  if (static_cast<int32>(buffer_.size()) == config_.minibatch_size)
47  TrainOneMinibatch();
48 }
std::vector< NnetExample > buffer_
NnetEnsembleTrainerConfig config_

Member Data Documentation

double avg_logprob_this_phase_
private
double beta_
private
double count_this_phase_
private
int32 minibatches_seen_this_phase_
private
std::vector<Nnet*> nnet_ensemble_
private

Definition at line 86 of file train-nnet-ensemble.h.

Referenced by NnetEnsembleTrainer::TrainOneMinibatch().

int32 num_phases_
private
std::vector<NnetUpdater*> updater_ensemble_
private

Definition at line 87 of file train-nnet-ensemble.h.

Referenced by NnetEnsembleTrainer::TrainOneMinibatch().


The documentation for this class was generated from the following files: