shrink-nnet.cc
// nnet2/shrink-nnet.cc

// Copyright 2012 Johns Hopkins University (author: Daniel Povey)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "nnet2/shrink-nnet.h"

namespace kaldi {
namespace nnet2 {

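// Computes the objective function of "nnet" on the validation set, per frame,
// after scaling each updatable component i by exp(log_scale_params(i)); also
// writes to "gradient" the derivative of that objective w.r.t. each of the
// log scale factors.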
static BaseFloat ComputeObjfAndGradient(
    const std::vector<NnetExample> &validation_set,
    const Vector<double> &log_scale_params,
    const Nnet &nnet,
    Vector<double> *gradient) {
  Vector<BaseFloat> scale_params(log_scale_params);
  scale_params.ApplyExp();
  Nnet nnet_scaled(nnet);
  nnet_scaled.ScaleComponents(scale_params);

  Nnet nnet_gradient(nnet);
  bool is_gradient = true;
  nnet_gradient.SetZero(is_gradient);

  // note: "ans" is normalized by the total weight of validation frames.
  int32 batch_size = 1024;
  BaseFloat ans = ComputeNnetGradient(nnet_scaled,
                                      validation_set,
                                      batch_size,
                                      &nnet_gradient);

  BaseFloat tot_count = validation_set.size();
  int32 i = 0;  // index into log_scale_params.
  for (int32 j = 0; j < nnet_scaled.NumComponents(); j++) {
    const UpdatableComponent *uc =
        dynamic_cast<const UpdatableComponent*>(&(nnet.GetComponent(j))),
        *uc_gradient =
        dynamic_cast<const UpdatableComponent*>(&(nnet_gradient.GetComponent(j)));
    if (uc != NULL) {
      BaseFloat dotprod = uc->DotProduct(*uc_gradient) / tot_count;
      (*gradient)(i) = dotprod * scale_params(i);  // gradient w.r.t. log of scaling factor.
      // We multiply by scale_params(i) to take into account d/dx exp(x); "gradient"
      // is the gradient w.r.t. the log of the scale_params.
      i++;
    }
  }
  KALDI_ASSERT(i == log_scale_params.Dim());
  return ans;
}

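// "Shrinks" *nnet by scaling the parameters of each updatable component; the
// per-component scale factors are optimized in the log domain with (L-)BFGS
// to maximize the objective function on the validation set.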
void ShrinkNnet(const NnetShrinkConfig &shrink_config,
                const std::vector<NnetExample> &validation_set,
                Nnet *nnet) {

  int32 dim = nnet->NumUpdatableComponents();
  KALDI_ASSERT(dim > 0);
  Vector<double> log_scale(dim), gradient(dim);  // will be zero.

  // Get initial gradient.
  double objf, initial_objf;

  LbfgsOptions lbfgs_options;
  lbfgs_options.minimize = false;  // We're maximizing.
  lbfgs_options.m = dim;  // Store the same number of vectors as the dimension
                          // itself, so this is BFGS.
  lbfgs_options.first_step_length = shrink_config.initial_step;

  OptimizeLbfgs<double> lbfgs(log_scale,
                              lbfgs_options);

  for (int32 i = 0; i < shrink_config.num_bfgs_iters; i++) {
    log_scale.CopyFromVec(lbfgs.GetProposedValue());
    objf = ComputeObjfAndGradient(validation_set, log_scale,
                                  *nnet,
                                  &gradient);

    KALDI_VLOG(2) << "log-scale = " << log_scale << ", objf = " << objf
                  << ", gradient = " << gradient;
    if (i == 0) initial_objf = objf;

    lbfgs.DoStep(objf, gradient);
  }

  log_scale.CopyFromVec(lbfgs.GetValue(&objf));

  Vector<BaseFloat> scale(log_scale);
  scale.ApplyExp();
  KALDI_LOG << "Shrinking nnet, validation objf per frame changed from "
            << initial_objf << " to " << objf << ", scale factors per layer are "
            << scale;
  nnet->ScaleComponents(scale);
}


} // namespace nnet2
} // namespace kaldi
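For orientation, here is a minimal sketch of how ShrinkNnet might be driven. The helper ShrinkFromExamples is hypothetical (not part of Kaldi), and the sketch assumes the SequentialNnetExampleReader typedef from nnet2/nnet-example.h; only NnetShrinkConfig and ShrinkNnet themselves come from shrink-nnet.h.

// Sketch only: ShrinkFromExamples is a hypothetical helper, not part of this
// file or of Kaldi; SequentialNnetExampleReader is assumed to be available
// from nnet2/nnet-example.h.
#include <string>
#include <vector>
#include "nnet2/nnet-example.h"
#include "nnet2/shrink-nnet.h"

void ShrinkFromExamples(const std::string &valid_examples_rspecifier,
                        kaldi::nnet2::Nnet *nnet) {
  using namespace kaldi;
  using namespace kaldi::nnet2;

  // ShrinkNnet re-evaluates the validation objective on every BFGS step, so
  // the held-out examples are read into memory once up front.
  std::vector<NnetExample> validation_set;
  SequentialNnetExampleReader example_reader(valid_examples_rspecifier);
  for (; !example_reader.Done(); example_reader.Next())
    validation_set.push_back(example_reader.Value());

  // Default options (number of BFGS iterations, initial step length); see
  // NnetShrinkConfig in shrink-nnet.h for the actual fields and defaults.
  NnetShrinkConfig shrink_config;
  ShrinkNnet(shrink_config, validation_set, nnet);
}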