train-nnet-ensemble.h
// nnet2/train-nnet-ensemble.h

// Copyright 2012  Johns Hopkins University (author: Daniel Povey)
//           2014  Xiaohui Zhang

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS
// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY
// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_NNET2_TRAIN_NNET_ENSEMBLE_H_
#define KALDI_NNET2_TRAIN_NNET_ENSEMBLE_H_

#include "nnet2/nnet-update.h"
#include "nnet2/nnet-compute.h"
#include "itf/options-itf.h"

namespace kaldi {
namespace nnet2 {

struct NnetEnsembleTrainerConfig {
  int32 minibatch_size;
  int32 minibatches_per_phase;
  double beta;

  NnetEnsembleTrainerConfig(): minibatch_size(500),
                               minibatches_per_phase(50),
                               beta(0.5) { }

  void Register(OptionsItf *opts) {
    opts->Register("minibatch-size", &minibatch_size,
                   "Number of samples per minibatch of training data.");
    opts->Register("minibatches-per-phase", &minibatches_per_phase,
                   "Number of minibatches to wait before printing training-set "
                   "objective.");
    opts->Register("beta", &beta,
                   "Weight of the second term in the objective function: the "
                   "cross-entropy between each net's output posteriors and the "
                   "averaged posteriors from the other nets.");
  }
};
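// Usage sketch (editor's addition, not part of the original header): these
// options are normally registered and parsed in a command-line binary via
// kaldi::ParseOptions (util/parse-options.h), which implements OptionsItf.
// The usage string below is illustrative.
//
//   ParseOptions po("Usage: nnet-train-ensemble [options] ...");
//   NnetEnsembleTrainerConfig train_config;
//   train_config.Register(&po);  // adds --minibatch-size,
//                                // --minibatches-per-phase, --beta
//   po.Read(argc, argv);         // parses the command line into train_config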


// Like NnetTrainer, the class NnetEnsembleTrainer batches up the input
// into minibatches and feeds each minibatch into every nnet in the
// ensemble, calling Propagate() to do the forward propagation and
// collecting the output posteriors. The posteriors from the different
// nets are averaged and then used to compute the additional term in
// the objective function: (a constant times) the cross-entropy between
// each net's output posteriors and the averaged posteriors of the whole
// ensemble. We also compute the derivatives and call Backprop() to
// update each net separately.
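// In symbols (editor's reading of the comment above): writing p_i for net
// i's output posteriors on a frame, t for the supervision targets, and
// p_avg = (1/N) * sum_j p_j for the average over the N nets, each net is
// trained to maximize
//
//   obj_i = sum_s t(s) * log p_i(s)  +  beta * sum_s p_avg(s) * log p_i(s),
//
// which, since the objective is linear in the targets, amounts to training
// on the interpolated soft targets t + beta * p_avg.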

class NnetEnsembleTrainer {
 public:
  NnetEnsembleTrainer(const NnetEnsembleTrainerConfig &config,
                      std::vector<Nnet*> nnet_ensemble);

  /// TrainOnExample will take the example and add it to a buffer;
  /// if we've reached the minibatch size it will do the training.
  void TrainOnExample(const NnetExample &value);

  ~NnetEnsembleTrainer();
 private:
  KALDI_DISALLOW_COPY_AND_ASSIGN(NnetEnsembleTrainer);

  void TrainOneMinibatch();

  // The following function is called by TrainOneMinibatch()
  // when we enter a new phase.
  void BeginNewPhase(bool first_time);

  // Things we were given in the initializer:
  NnetEnsembleTrainerConfig config_;

  std::vector<Nnet*> nnet_ensemble_;  // the nnet ensemble we're training.
  std::vector<NnetUpdater*> updater_ensemble_;

  // State information:
  int32 num_phases_;
  int32 minibatches_seen_this_phase_;
  std::vector<NnetExample> buffer_;

  // Weight used when interpolating the supervision with the averaged
  // posteriors.
  double beta_;
  double avg_logprob_this_phase_;  // Needed for accumulating train log-prob on each phase.
  double count_this_phase_;  // count corresponding to the above.
};
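// Usage sketch (editor's addition): a typical driver loop, modeled on the
// nnet2 training binaries; variable names and rspecifiers are illustrative.
//
//   std::vector<Nnet*> ensemble;  // one Nnet* per ensemble member, owned elsewhere
//   NnetEnsembleTrainer trainer(train_config, ensemble);
//   SequentialNnetExampleReader example_reader(examples_rspecifier);
//   for (; !example_reader.Done(); example_reader.Next())
//     trainer.TrainOnExample(example_reader.Value());
//   // trainer's destructor runs here; any partially filled minibatch
//   // left in the buffer is presumably flushed at this point.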


} // namespace nnet2
} // namespace kaldi

#endif