nnet-fix.cc
// nnet2/nnet-fix.cc

// Copyright 2012 Johns Hopkins University (author: Daniel Povey)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS
// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY
// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "nnet2/nnet-fix.h"

namespace kaldi {
namespace nnet2 {


/* See the header for what we're doing.
   The pattern we're looking for is an AffineComponent followed by
   a NonlinearComponent of type SigmoidComponent, TanhComponent or
   RectifiedLinearComponent.
*/
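/* Strategy, stated briefly: for each affine layer feeding such a
   nonlinearity, look at the average derivative of each output dimension,
   taken from the statistics the NonlinearComponent has accumulated
   (Count() / DerivSum()).  Units that look saturated get their incoming
   weights and bias scaled down (or, for ReLU, the bias increased); units
   that look too linear get the opposite treatment. */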

void FixNnet(const NnetFixConfig &config, Nnet *nnet) {
  for (int32 c = 0; c + 1 < nnet->NumComponents(); c++) {
    AffineComponent *ac = dynamic_cast<AffineComponent*>(
        &(nnet->GetComponent(c)));
    NonlinearComponent *nc = dynamic_cast<NonlinearComponent*>(
        &(nnet->GetComponent(c + 1)));
    if (ac == NULL || nc == NULL) continue;
    // We only want to process this if the nonlinearity is of type
    // SigmoidComponent, TanhComponent or RectifiedLinearComponent.
    BaseFloat max_deriv;  // The maximum derivative of this nonlinearity.
    bool is_relu = false;
    {
      SigmoidComponent *sc = dynamic_cast<SigmoidComponent*>(nc);
      TanhComponent *tc = dynamic_cast<TanhComponent*>(nc);
      RectifiedLinearComponent *rc = dynamic_cast<RectifiedLinearComponent*>(nc);
      if (sc != NULL) max_deriv = 0.25;
      else if (tc != NULL) max_deriv = 1.0;
      else if (rc != NULL) { max_deriv = 1.0; is_relu = true; }
      else continue;  // E.g. SoftmaxComponent; we don't handle this.
    }
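    // Note on the values of max_deriv: the sigmoid derivative
    // s(x)(1 - s(x)) peaks at 0.25 (at x = 0), the tanh derivative
    // 1 - tanh^2(x) peaks at 1.0 (at x = 0), and a ReLU has derivative
    // exactly 1.0 wherever it is active; so the per-dimension average
    // derivative (deriv_sum / count) lies in [0, max_deriv].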
    double count = nc->Count();
    Vector<double> deriv_sum(nc->DerivSum());
    if (count == 0.0 || deriv_sum.Dim() == 0) {
      KALDI_WARN << "Cannot fix neural net because no statistics are stored.";
      continue;
    }
    Vector<BaseFloat> bias_params(ac->BiasParams());
    Matrix<BaseFloat> linear_params(ac->LinearParams());
    int32 dim = nc->InputDim(), num_small_deriv = 0, num_large_deriv = 0;
    for (int32 d = 0; d < dim; d++) {
      // deriv_ratio is the ratio of the computed average derivative to the
      // maximum derivative of that nonlinear function.
      BaseFloat deriv_ratio = deriv_sum(d) / (count * max_deriv);
      // If this assertion fails, there is an error in the math above.
      KALDI_ASSERT(deriv_ratio >= 0.0 && deriv_ratio < 1.01);
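      // Worked example with made-up numbers: for a sigmoid unit with
      // count = 1e6 and deriv_sum(d) = 5e4, the average derivative is
      // 0.05 and deriv_ratio = 5e4 / (1e6 * 0.25) = 0.2; if that falls
      // below config.min_average_deriv, the unit is treated as saturated
      // and its weights and bias are scaled down below.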
      if (deriv_ratio < config.min_average_deriv) {
        // The derivative is too small, meaning we've gone off into the "flat
        // part" of the sigmoid (or for ReLU, we're always off).
        if (is_relu) {
          bias_params(d) += config.relu_bias_change;
        } else {
          BaseFloat parameter_factor = std::min(config.min_average_deriv /
                                                deriv_ratio,
                                                config.parameter_factor);
          // We need to reduce the parameters, so multiply by
          // 1.0 / parameter_factor.
          bias_params(d) *= 1.0 / parameter_factor;
          linear_params.Row(d).Scale(1.0 / parameter_factor);
        }
        num_small_deriv++;
      } else if (deriv_ratio > config.max_average_deriv) {
        // The derivative is too large, meaning we're only in the linear part
        // of the sigmoid, in the middle (or for ReLU, we're always on).
        if (is_relu) {
          bias_params(d) -= config.relu_bias_change;
        } else {
          BaseFloat parameter_factor = std::min(deriv_ratio / config.max_average_deriv,
                                                config.parameter_factor);
          // We need to increase the parameters, so multiply by
          // parameter_factor.
          bias_params(d) *= parameter_factor;
          linear_params.Row(d).Scale(parameter_factor);
        }
        num_large_deriv++;
      }
    }
    if (is_relu) {
      KALDI_LOG << "For layer " << c << " (ReLU units), increased bias for "
                << num_small_deriv << " indexes and decreased it for "
                << num_large_deriv << ", out of a total of " << dim;
    } else {
      KALDI_LOG << "For layer " << c << ", decreased parameters for "
                << num_small_deriv << " indexes and increased them for "
                << num_large_deriv << ", out of a total of " << dim;
    }
    ac->SetParams(bias_params, linear_params);
  }
}


}  // namespace nnet2
}  // namespace kaldi
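
For context, the following is a minimal sketch of a driver program that applies FixNnet to a trained acoustic model, along the lines of the nnet2 command-line tool nnet-am-fix. The program name and the model I/O details shown here are assumptions for illustration; they are not part of this file.

// fix-nnet-example.cc (hypothetical driver, modeled on nnet2's nnet-am-fix)
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "hmm/transition-model.h"
#include "nnet2/am-nnet.h"
#include "nnet2/nnet-fix.h"

int main(int argc, char *argv[]) {
  using namespace kaldi;
  using namespace kaldi::nnet2;
  const char *usage =
      "Fix saturated or overly-linear hidden units in an nnet2 model.\n"
      "Usage: fix-nnet-example [options] <model-in> <model-out>\n";
  ParseOptions po(usage);
  NnetFixConfig config;
  bool binary_write = true;
  po.Register("binary", &binary_write, "Write output in binary mode");
  config.Register(&po);  // register the FixNnet options with the parser
  po.Read(argc, argv);
  if (po.NumArgs() != 2) {
    po.PrintUsage();
    exit(1);
  }
  std::string nnet_rxfilename = po.GetArg(1),
      nnet_wxfilename = po.GetArg(2);

  // Read the transition model and acoustic model from one file.
  TransitionModel trans_model;
  AmNnet am_nnet;
  {
    bool binary_read;
    Input ki(nnet_rxfilename, &binary_read);
    trans_model.Read(ki.Stream(), binary_read);
    am_nnet.Read(ki.Stream(), binary_read);
  }
  FixNnet(config, &(am_nnet.GetNnet()));  // the function defined above
  {
    Output ko(nnet_wxfilename, binary_write);
    trans_model.Write(ko.Stream(), binary_write);
    am_nnet.Write(ko.Stream(), binary_write);
  }
  KALDI_LOG << "Fixed neural net and wrote it to " << nnet_wxfilename;
  return 0;
}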