combine-nnet-a.cc
// nnet2/combine-nnet-a.cc

// Copyright 2012  Johns Hopkins University (author: Daniel Povey)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "nnet2/combine-nnet-a.h"

namespace kaldi {
namespace nnet2 {
/*
  This function gets the "update direction".  The vector "nnets" is
  interpreted as (old-nnet new-nnet1 new-nnet2 ... new-nnetN), and
  the "update direction" is the average of the new nnets, minus the
  old nnet.
*/
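// For example, with two new nnets:
//   direction = (new-nnet1 + new-nnet2) / 2 - old-nnet.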
static void GetUpdateDirection(const std::vector<Nnet> &nnets,
                               Nnet *direction) {
  KALDI_ASSERT(nnets.size() > 1);
  int32 num_new_nnets = nnets.size() - 1;
  Vector<BaseFloat> scales(nnets[0].NumUpdatableComponents());

  scales.Set(1.0 / num_new_nnets);

  *direction = nnets[1];
  direction->ScaleComponents(scales);  // first of the new nnets.
  for (int32 n = 2; n < 1 + num_new_nnets; n++)
    direction->AddNnet(scales, nnets[n]);
  // now "direction" is the average of the new nnets.  Subtract
  // the old nnet's parameters.
  scales.Set(-1.0);
  direction->AddNnet(scales, nnets[0]);
}

// Sets "dest" to orig_nnet plus "direction", with each updatable
// component of "direction" first scaled by the appropriate scale.
static void AddDirection(const Nnet &orig_nnet,
                         const Nnet &direction,
                         const VectorBase<BaseFloat> &scales,
                         Nnet *dest) {
  *dest = orig_nnet;
  dest->AddNnet(scales, direction);
}


static BaseFloat ComputeObjfAndGradient(
    const std::vector<NnetExample> &validation_set,
    const Vector<double> &scale_params,
    const Nnet &orig_nnet,
    const Nnet &direction,
    Vector<double> *gradient) {

  Vector<BaseFloat> scale_params_float(scale_params);

  Nnet nnet_combined;
  AddDirection(orig_nnet, direction, scale_params_float, &nnet_combined);

  Nnet nnet_gradient(nnet_combined);
  bool is_gradient = true;
  nnet_gradient.SetZero(is_gradient);
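  // With is_gradient == true, SetZero() zeroes the parameters and configures
  // the copy so that backprop accumulates the raw parameter gradient in it.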

  // note: "ans" is normalized by the total weight of validation frames.
  int32 batch_size = 1024;
  BaseFloat ans = ComputeNnetGradient(nnet_combined,
                                      validation_set,
                                      batch_size,
                                      &nnet_gradient);

  BaseFloat tot_count = validation_set.size();
  int32 i = 0;  // index into scale_params.
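  // Chain rule: the objective is F(orig_nnet + sum_i scale_i * direction_i),
  // so d(objf)/d(scale_i) is the dot product of the i'th updatable component
  // of "direction" with the corresponding component of the gradient.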
  for (int32 j = 0; j < nnet_combined.NumComponents(); j++) {
    const UpdatableComponent *uc_direction =
        dynamic_cast<const UpdatableComponent*>(&(direction.GetComponent(j))),
        *uc_gradient =
        dynamic_cast<const UpdatableComponent*>(&(nnet_gradient.GetComponent(j)));
    if (uc_direction != NULL) {
      BaseFloat dotprod = uc_direction->DotProduct(*uc_gradient) / tot_count;
      (*gradient)(i) = dotprod;
      i++;
    }
  }
  KALDI_ASSERT(i == scale_params.Dim());
  return ans;
}


void CombineNnetsA(const NnetCombineAconfig &config,
                   const std::vector<NnetExample> &validation_set,
                   const std::vector<Nnet> &nnets,
                   Nnet *nnet_out) {

  Nnet direction;  // the update direction = avg(nnets[1 ... N]) - nnets[0].
  GetUpdateDirection(nnets, &direction);

  Vector<double> scale_params(nnets[0].NumUpdatableComponents());  // initial
  // scale on "direction".

  int32 dim = scale_params.Dim();
  KALDI_ASSERT(dim > 0);
  Vector<double> gradient(dim);

  double objf, initial_objf, zero_objf;

  // Compute objf at zero; we don't actually need this gradient.
  zero_objf = ComputeObjfAndGradient(validation_set,
                                     scale_params,
                                     nnets[0],
                                     direction,
                                     &gradient);
  KALDI_LOG << "Objective function at old parameters is "
            << zero_objf;

  scale_params.Set(1.0);  // start optimization from the average of the parameters.

  LbfgsOptions lbfgs_options;
  lbfgs_options.minimize = false;  // We're maximizing.
  lbfgs_options.m = dim;  // Store the same number of vectors as the dimension
                          // itself, so this is BFGS.
  lbfgs_options.first_step_length = config.initial_step;

  OptimizeLbfgs<double> lbfgs(scale_params,
                              lbfgs_options);

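  // OptimizeLbfgs has an ask/tell interface: GetProposedValue() says where to
  // evaluate next, DoStep() reports the objective and gradient at that point,
  // and GetValue() later returns the best point seen so far.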
  for (int32 i = 0; i < config.num_bfgs_iters; i++) {
    scale_params.CopyFromVec(lbfgs.GetProposedValue());
    objf = ComputeObjfAndGradient(validation_set,
                                  scale_params,
                                  nnets[0],
                                  direction,
                                  &gradient);

    KALDI_VLOG(2) << "Iteration " << i << " scale-params = " << scale_params
                  << ", objf = " << objf << ", gradient = " << gradient;

    if (i == 0) initial_objf = objf;
    lbfgs.DoStep(objf, gradient);
  }

  scale_params.CopyFromVec(lbfgs.GetValue(&objf));

  KALDI_LOG << "Combining nnets, after BFGS, validation objf per frame changed from "
            << zero_objf << " (no change), or " << initial_objf << " (default change), "
            << " to " << objf << "; scale factors on update direction are "
            << scale_params;

  BaseFloat objf_change = objf - zero_objf;
  KALDI_ASSERT(objf_change >= 0.0);  // This is guaranteed by the L-BFGS code.

  if (objf_change < config.valid_impr_thresh) {
    // We'll overshoot.  To have a smooth transition between the two regimes, if
    // objf_change is close to valid_impr_thresh we don't overshoot as far.
    BaseFloat overshoot = config.overshoot,
        overshoot_max = config.valid_impr_thresh / objf_change;  // >= 1.0.
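    // E.g. if objf_change were exactly half of valid_impr_thresh,
    // overshoot_max would be 2.0.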
    if (overshoot_max < overshoot) {
      KALDI_LOG << "Limiting overshoot from " << overshoot << " to " << overshoot_max
                << " since the objf-impr " << objf_change << " is close to "
                << "--valid-impr-thresh=" << config.valid_impr_thresh;
      overshoot = overshoot_max;
    }
    KALDI_ASSERT(overshoot < 2.0 && "--overshoot must be < 2.0 or "
                 "it will lead to instability.");
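    // Rationale: for a quadratic objective whose optimum is at the step we
    // just found, scaling that step by a factor f multiplies the improvement
    // by f * (2 - f), which stays positive for any f in (0, 2); hence the
    // 2.0 stability limit above.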
    scale_params.Scale(overshoot);

    BaseFloat optimized_objf = objf;
    objf = ComputeObjfAndGradient(validation_set,
                                  scale_params,
                                  nnets[0],
                                  direction,
                                  &gradient);

    KALDI_LOG << "Combining nnets, after overshooting, validation objf changed "
              << "to " << objf << ".  Note: (zero, start, optimized) objfs were "
              << zero_objf << ", " << initial_objf << ", " << optimized_objf;
    if (objf < zero_objf) {
      // Note: this should not happen according to a quadratic approximation, and we
      // expect this branch to be taken only rarely if at all.
      KALDI_WARN << "After overshooting, objf was worse than not updating; not doing the "
                 << "overshoot.";
      scale_params.Scale(1.0 / overshoot);
    }
  }  // Else don't do the "overshoot" stuff.

  Vector<BaseFloat> scale_params_float(scale_params);
  // Output to "nnet_out":
  AddDirection(nnets[0], direction, scale_params_float, nnet_out);

  // Now update the neural net learning rates.
  int32 i = 0;
  for (int32 j = 0; j < nnet_out->NumComponents(); j++) {
    UpdatableComponent *uc =
        dynamic_cast<UpdatableComponent*>(&(nnet_out->GetComponent(j)));
    if (uc != NULL) {
      BaseFloat step_length = scale_params(i), factor = step_length;
      // Our basic rule is to update the learning rate by multiplying it
      // by "step_length", but this is subject to certain limits.
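      // Intuition: a step length above 1.0 suggests the SGD jobs moved in a
      // useful direction but not far enough, i.e. the learning rate was too
      // low, so we raise it; a step length below 1.0 suggests the opposite.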
      if (factor < config.min_learning_rate_factor)
        factor = config.min_learning_rate_factor;
      if (factor > config.max_learning_rate_factor)
        factor = config.max_learning_rate_factor;
      BaseFloat new_learning_rate = factor * uc->LearningRate();
      if (new_learning_rate < config.min_learning_rate)
        new_learning_rate = config.min_learning_rate;
      KALDI_LOG << "For component " << j << ", step length was " << step_length
                << ", updating learning rate by factor " << factor << ", changing "
                << "learning rate from " << uc->LearningRate() << " to "
                << new_learning_rate;
      uc->SetLearningRate(new_learning_rate);
      i++;
    }
  }
}


}  // namespace nnet2
}  // namespace kaldi
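
As a usage note, below is a minimal sketch of how CombineNnetsA might be driven from training code. The setup is assumed rather than taken from this file: loading of the previous nnet, the nnets produced by the parallel SGD jobs, and the held-out validation examples is elided, and the option values shown are hypothetical.

// Hypothetical driver, not part of this file: combine the nnets produced by
// the parallel SGD jobs at the end of one outer training iteration.
NnetCombineAconfig config;
config.num_bfgs_iters = 30;  // hypothetical setting
config.initial_step = 0.1;   // hypothetical setting

std::vector<NnetExample> validation_set;  // held-out frames (loading elided)
std::vector<Nnet> nnets;  // nnets[0] = old nnet; nnets[1..N] = new nnets

Nnet nnet_out;
CombineNnetsA(config, validation_set, nnets, &nnet_out);
// nnet_out now holds nnets[0] plus the optimized, scaled update direction,
// with per-component learning rates adjusted from the step lengths.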