// nnet2/nnet-update.h

// Copyright 2012  Johns Hopkins University (author: Daniel Povey)
//           2014  Xiaohui Zhang

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS
// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY
// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_NNET2_NNET_UPDATE_H_
#define KALDI_NNET2_NNET_UPDATE_H_

#include "nnet2/nnet-nnet.h"
#include "nnet2/nnet-example.h"
#include "util/table-types.h"


namespace kaldi {
namespace nnet2 {

class NnetEnsembleTrainer;

// This class NnetUpdater contains functions for updating the neural net or
// computing its gradient, given a set of NnetExamples.  We define it in the
// header file because it's needed by the ensemble training code.  But in
// normal cases its functionality should be accessed by calling DoBackprop()
// or ComputeNnetObjf(), declared below.
class NnetUpdater {
 public:
  // Note: in the case of training with SGD, "nnet" and "nnet_to_update" will
  // be identical.  They'll be different if we're accumulating the gradient
  // for a held-out set and don't want to update the model.  Note:
  // nnet_to_update may be NULL if you don't want to do backprop.
  NnetUpdater(const Nnet &nnet,
              Nnet *nnet_to_update);

  /// Does the entire forward and backward computation for this minibatch.
  double ComputeForMinibatch(const std::vector<NnetExample> &data,
                             double *tot_accuracy);

  /// As above, but takes the input already formatted as a single matrix by
  /// FormatNnetInput() (declared below).
  double ComputeForMinibatch(const std::vector<NnetExample> &data,
                             Matrix<BaseFloat> *formatted_data,
                             double *tot_accuracy);

  void GetOutput(CuMatrix<BaseFloat> *output);
 protected:

  void Propagate();

  /// Formats the input as a single matrix, sets the size of forward_data_,
  /// and sets up chunk_info_out_.
  void FormatInput(const std::vector<NnetExample> &data);

  /// Computes the objective function and the derivative at the output layer,
  /// but does not do the backprop [for that, see Backprop()].
  double ComputeObjfAndDeriv(const std::vector<NnetExample> &data,
                             CuMatrix<BaseFloat> *deriv,
                             double *tot_accuracy = NULL) const;

  /// Backprop must be called after ComputeObjfAndDeriv.
  void Backprop(CuMatrix<BaseFloat> *deriv) const;

  friend class NnetEnsembleTrainer;
 private:
  // Must be called after Propagate().
  double ComputeTotAccuracy(const std::vector<NnetExample> &data) const;

  const Nnet &nnet_;
  Nnet *nnet_to_update_;
  int32 num_chunks_;  // same as the minibatch size.
  std::vector<ChunkInfo> chunk_info_out_;

  std::vector<CuMatrix<BaseFloat> > forward_data_;  // The forward data
  // for the outputs of each of the components.

};
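
// Example (illustrative sketch): driving NnetUpdater directly, e.g. from
// trainer code.  Here `nnet` and `minibatch` are assumed to be set up
// elsewhere; in normal cases you would instead call DoBackprop() or
// ComputeNnetObjf(), declared below.
//
//   NnetUpdater updater(nnet, &nnet);  // SGD case: update the model itself.
//   double tot_accuracy = 0.0;
//   double tot_objf = updater.ComputeForMinibatch(minibatch, &tot_accuracy);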

/// Takes the input to the nnet for a minibatch of examples, and formats it
/// as a single matrix.
void FormatNnetInput(const Nnet &nnet,
                     const std::vector<NnetExample> &data,
                     Matrix<BaseFloat> *mat);

/// This function computes the objective function and either updates the model
/// or adds to the parameter gradients, depending on what nnet_to_update
/// points to (it may also be NULL, in which case no backprop is done).
/// Returns the objective function summed over the examples.
double DoBackprop(const Nnet &nnet,
                  const std::vector<NnetExample> &examples,
                  Nnet *nnet_to_update,
                  double *tot_accuracy = NULL);

/// As above, but takes the examples' input already formatted as a single
/// matrix by FormatNnetInput() (see above).
double DoBackprop(const Nnet &nnet,
                  const std::vector<NnetExample> &examples,
                  Matrix<BaseFloat> *examples_formatted,
                  Nnet *nnet_to_update,
                  double *tot_accuracy = NULL);
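
// Example (illustrative sketch): the two typical uses of DoBackprop().
// `nnet`, `gradient`, and `examples` are assumed to exist; `gradient` would
// be a zeroed copy of the model used as a gradient accumulator.
//
//   // SGD training: nnet_to_update is the model itself.
//   double objf = DoBackprop(nnet, examples, &nnet);
//
//   // Gradient accumulation on held-out data: the model is left unchanged.
//   double held_out_objf = DoBackprop(nnet, examples, &gradient);
//
//   // Per-frame average, normalizing by the total example weight
//   // (TotalNnetTrainingWeight is declared below).
//   double avg_objf = objf / TotalNnetTrainingWeight(examples);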


/// Returns the total weight summed over all the examples.
BaseFloat TotalNnetTrainingWeight(const std::vector<NnetExample> &egs);

/// Computes the objective function over a minibatch.
double ComputeNnetObjf(const Nnet &nnet,
                       const std::vector<NnetExample> &examples,
                       double *tot_accuracy = NULL);

/// As above, but splits the examples up into minibatches of size
/// minibatch_size when computing the objective function.
double ComputeNnetObjf(const Nnet &nnet,
                       const std::vector<NnetExample> &examples,
                       int32 minibatch_size,
                       double *tot_accuracy = NULL);
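
// Example (illustrative sketch): objective function on a held-out set,
// without touching the model.  `nnet` and `valid_examples` are assumed to
// exist, 512 is just an example minibatch size, and we assume the return
// value is summed over the examples, so we normalize by the total weight.
//
//   double tot_objf = ComputeNnetObjf(nnet, valid_examples, 512);
//   double avg_objf = tot_objf / TotalNnetTrainingWeight(valid_examples);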

/// ComputeNnetGradient is mostly used to compute gradients on validation
/// sets; it divides the examples up into batches of size batch_size and
/// accumulates the gradient into *gradient.
double ComputeNnetGradient(
    const Nnet &nnet,
    const std::vector<NnetExample> &examples,
    int32 batch_size,
    Nnet *gradient);
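
// Example (illustrative sketch): accumulating a validation-set gradient.
// `nnet` and `valid_examples` are assumed to exist; zeroing the gradient
// copy via SetZero(true) follows the usual nnet2 pattern, but check the
// Nnet interface in nnet2/nnet-nnet.h.
//
//   Nnet gradient(nnet);     // copy the model's structure and dimensions.
//   gradient.SetZero(true);  // zero the parameters; treat as a gradient.
//   double objf = ComputeNnetGradient(nnet, valid_examples, 512, &gradient);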

} // namespace nnet2
} // namespace kaldi

#endif // KALDI_NNET2_NNET_UPDATE_H_