nnet-modify-learning-rates.cc
// nnet2bin/nnet-modify-learning-rates.cc

// Copyright 2013  Guoguo Chen

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS
// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY
// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "hmm/transition-model.h"
#include "nnet2/train-nnet.h"
#include "nnet2/am-nnet.h"

namespace kaldi {
namespace nnet2 {

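// Helper that sets the max-change value on every AffineComponentPreconditioned
// in the net; components of other types are left untouched. In nnet2, a
// max-change of 0.0 means the per-minibatch limit on parameter movement is
// inactive.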
void SetMaxChange(BaseFloat max_change, Nnet *nnet) {
  for (int32 c = 0; c < nnet->NumComponents(); c++) {
    Component *component = &(nnet->GetComponent(c));
    AffineComponentPreconditioned *ac =
        dynamic_cast<AffineComponentPreconditioned*>(component);
    if (ac != NULL)
      ac->SetMaxChange(max_change);
  }
}

}  // namespace nnet2
}  // namespace kaldi

int main(int argc, char *argv[]) {
  try {
    using namespace kaldi;
    using namespace kaldi::nnet2;
    typedef kaldi::int32 int32;
    typedef kaldi::int64 int64;

    const char *usage =
        "This program modifies the learning rates so as to equalize the\n"
        "relative changes in parameters for each layer, while keeping their\n"
        "geometric mean the same (or changing it to a value specified using\n"
        "the --average-learning-rate option).\n"
        "\n"
        "Usage: nnet-modify-learning-rates [options] <prev-model> \\\n"
        "                                  <cur-model> <modified-cur-model>\n"
        "e.g.: nnet-modify-learning-rates --average-learning-rate=0.0002 \\\n"
        "                                 5.mdl 6.mdl 6.mdl\n";

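    // Concretely: if r_i is the relative parameter change of updatable
    // component i between <prev-model> and <cur-model>, the new learning
    // rates are lr_i = C * prev_lr_i / r_i (optionally scaled further for
    // the first and last layers), with the constant C chosen so that the
    // geometric mean of the lr_i matches the target.
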
    bool binary_write = true;
    bool retroactive = false;
    BaseFloat average_learning_rate = 0.0;
    BaseFloat first_layer_factor = 1.0;
    BaseFloat last_layer_factor = 1.0;

    ParseOptions po(usage);
    po.Register("binary", &binary_write, "Write output in binary mode");
    po.Register("average-learning-rate", &average_learning_rate,
                "If supplied, change the geometric mean of the learning "
                "rates to the given value.");
    po.Register("first-layer-factor", &first_layer_factor, "Factor that "
                "reduces the target relative learning rate for the first "
                "layer.");
    po.Register("last-layer-factor", &last_layer_factor, "Factor that "
                "reduces the target relative learning rate for the last "
                "layer.");
    po.Register("retroactive", &retroactive, "If true, scale the parameter "
                "differences as well.");

    po.Read(argc, argv);

    if (po.NumArgs() != 3) {
      po.PrintUsage();
      exit(1);
    }

    KALDI_ASSERT(average_learning_rate >= 0);

    std::string prev_nnet_rxfilename = po.GetArg(1),
        cur_nnet_rxfilename = po.GetArg(2),
        modified_cur_nnet_wxfilename = po.GetArg(3);

    TransitionModel trans_model;
    AmNnet am_prev_nnet, am_cur_nnet;
    {
      bool binary_read;
      Input ki(prev_nnet_rxfilename, &binary_read);
      trans_model.Read(ki.Stream(), binary_read);
      am_prev_nnet.Read(ki.Stream(), binary_read);
    }
    {
      bool binary_read;
      Input ki(cur_nnet_rxfilename, &binary_read);
      trans_model.Read(ki.Stream(), binary_read);
      am_cur_nnet.Read(ki.Stream(), binary_read);
    }

    if (am_prev_nnet.GetNnet().GetParameterDim() !=
        am_cur_nnet.GetNnet().GetParameterDim()) {
      KALDI_WARN << "Parameter-dim mismatch, cannot equalize the relative "
                 << "changes in parameters for each layer.";
      exit(0);
    }

    int32 ret = 0;

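    // For each updatable component i, the relative parameter difference is
    //   relative_diff(i) = ||prev_i - cur_i|| / ||prev_i||,
    // where ||.|| is the 2-norm over that component's parameters; the norms
    // are obtained below as dot products of a net with itself, followed by
    // a square root.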
    // Gets relative parameter differences.
    int32 num_updatable = am_prev_nnet.GetNnet().NumUpdatableComponents();
    Vector<BaseFloat> relative_diff(num_updatable);
    {
      Nnet diff_nnet(am_prev_nnet.GetNnet());
      diff_nnet.AddNnet(-1.0, am_cur_nnet.GetNnet());
      diff_nnet.ComponentDotProducts(diff_nnet, &relative_diff);
      relative_diff.ApplyPow(0.5);
      Vector<BaseFloat> baseline_prod(num_updatable);
      am_prev_nnet.GetNnet().ComponentDotProducts(am_prev_nnet.GetNnet(),
                                                  &baseline_prod);
      baseline_prod.ApplyPow(0.5);
      relative_diff.DivElements(baseline_prod);
      KALDI_LOG << "Relative parameter differences per layer are "
                << relative_diff;

      // If the relative parameter difference for some component is exactly
      // zero (the learning rates are divided by this vector below, so a zero
      // would blow up), set it to the mean of the nonzero values.
      int32 num_zero = 0;
      for (int32 i = 0; i < num_updatable; i++) {
        if (relative_diff(i) == 0.0) {
          num_zero++;
        }
      }
      if (num_zero > 0) {
        BaseFloat average_diff = relative_diff.Sum()
            / static_cast<BaseFloat>(num_updatable - num_zero);
        for (int32 i = 0; i < num_updatable; i++) {
          if (relative_diff(i) == 0.0) {
            relative_diff(i) = average_diff;
          }
        }
        KALDI_LOG << "Zeros detected in the relative parameter difference "
                  << "vector, updating the vector to " << relative_diff;
      }
    }

    // Gets the per-layer learning rates of the previous and current models.
    Vector<BaseFloat> prev_nnet_learning_rates(num_updatable),
        cur_nnet_learning_rates(num_updatable);
    am_prev_nnet.GetNnet().GetLearningRates(&prev_nnet_learning_rates);
    am_cur_nnet.GetNnet().GetLearningRates(&cur_nnet_learning_rates);
    KALDI_LOG << "Learning rates for previous model per layer are "
              << prev_nnet_learning_rates;
    KALDI_LOG << "Learning rates for current model per layer are "
              << cur_nnet_learning_rates;

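    // The geometric mean of n learning rates is exp(sum(log(lr_i)) / n);
    // if --average-learning-rate is not supplied, the current model's
    // geometric mean is kept as the target.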
    // Gets target geometric mean.
    BaseFloat target_geometric_mean = 0.0;
    if (average_learning_rate == 0.0) {
      target_geometric_mean = Exp(cur_nnet_learning_rates.SumLog()
          / static_cast<BaseFloat>(num_updatable));
    } else {
      target_geometric_mean = average_learning_rate;
    }
    KALDI_ASSERT(target_geometric_mean > 0.0);

    // Works out the new learning rates. We start from the previous model's
    // rates; this ensures that running this program twice gives consistent
    // results even if the first run has overwritten the current model.
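    // Layers whose parameters moved more get proportionally smaller rates
    // (lr_i proportional to prev_lr_i / relative_diff(i)); the vector is
    // then rescaled so that its geometric mean matches the target.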
    Vector<BaseFloat> nnet_learning_rates(prev_nnet_learning_rates);
    nnet_learning_rates.DivElements(relative_diff);
    KALDI_ASSERT(last_layer_factor > 0.0);
    nnet_learning_rates(num_updatable - 1) *= last_layer_factor;
    KALDI_ASSERT(first_layer_factor > 0.0);
    nnet_learning_rates(0) *= first_layer_factor;
    BaseFloat cur_geometric_mean = Exp(nnet_learning_rates.SumLog()
        / static_cast<BaseFloat>(num_updatable));
    nnet_learning_rates.Scale(target_geometric_mean / cur_geometric_mean);
    KALDI_LOG << "New learning rates for current model per layer are "
              << nnet_learning_rates;

    // Changes the parameter differences if --retroactive is set to true.
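    // Retroactive scaling rewrites cur = prev + s .* (cur - prev), where
    // s(i) = new_lr(i) / prev_lr(i) for layer i; i.e., the last parameter
    // change is redone as if the new learning rates had been in effect.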
    if (retroactive) {
      Vector<BaseFloat> scale_factors(nnet_learning_rates);
      scale_factors.DivElements(prev_nnet_learning_rates);
      am_cur_nnet.GetNnet().AddNnet(-1.0, am_prev_nnet.GetNnet());
      am_cur_nnet.GetNnet().ScaleComponents(scale_factors);
      am_cur_nnet.GetNnet().AddNnet(1.0, am_prev_nnet.GetNnet());
      KALDI_LOG << "Scaled the parameter differences retroactively; scaling "
                << "factors are " << scale_factors;
    }

    // Sets the learning rates and writes out the updated model.
    am_cur_nnet.GetNnet().SetLearningRates(nnet_learning_rates);

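    // Turns off the per-minibatch max-change limit on any
    // AffineComponentPreconditioned components in the written model
    // (0.0 means no limit; see SetMaxChange above).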
    SetMaxChange(0.0, &(am_cur_nnet.GetNnet()));

    Output ko(modified_cur_nnet_wxfilename, binary_write);
    trans_model.Write(ko.Stream(), binary_write);
    am_cur_nnet.Write(ko.Stream(), binary_write);

    return ret;
  } catch(const std::exception &e) {
    std::cerr << e.what() << '\n';
    return -1;
  }
}