nnet-modify-learning-rates.cc File Reference
Include dependency graph for nnet-modify-learning-rates.cc:

Go to the source code of this file.

Namespaces

 kaldi
 This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for mispronunciation detection tasks; for reference:
 
 kaldi::nnet2
 

Functions

void SetMaxChange (BaseFloat max_change, Nnet *nnet)
 
int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 41 of file nnet-modify-learning-rates.cc.

References Nnet::AddNnet(), VectorBase< Real >::ApplyPow(), Nnet::ComponentDotProducts(), VectorBase< Real >::DivElements(), kaldi::Exp(), ParseOptions::GetArg(), Nnet::GetLearningRates(), AmNnet::GetNnet(), ParseOptions::GetOptArg(), Nnet::GetParameterDim(), rnnlm::i, KALDI_ASSERT, KALDI_LOG, KALDI_WARN, ParseOptions::NumArgs(), Nnet::NumUpdatableComponents(), ParseOptions::PrintUsage(), AmNnet::Read(), ParseOptions::Read(), TransitionModel::Read(), ParseOptions::Register(), VectorBase< Real >::Scale(), Nnet::ScaleComponents(), Nnet::SetLearningRates(), kaldi::nnet2::SetMaxChange(), Output::Stream(), Input::Stream(), VectorBase< Real >::SumLog(), AmNnet::Write(), and TransitionModel::Write().

41  {
42  try {
43  using namespace kaldi;
44  using namespace kaldi::nnet2;
45  typedef kaldi::int32 int32;
46  typedef kaldi::int64 int64;
47 
48  const char *usage =
49  "This program modifies the learning rates so as to equalize the\n"
50  "relative changes in parameters for each layer, while keeping their\n"
51  "geometric mean the same (or changing it to a value specified using\n"
52  "the --average-learning-rate option).\n"
53  "\n"
54  "Usage: nnet-modify-learning-rates [options] <prev-model> \\\n"
55  " <cur-model> <modified-cur-model>\n"
56  "e.g.: nnet-modify-learning-rates --average-learning-rate=0.0002 \\\n"
57  " 5.mdl 6.mdl 6.mdl\n";
58 
59  bool binary_write = true;
60  bool retroactive = false;
61  BaseFloat average_learning_rate = 0.0;
62  BaseFloat first_layer_factor = 1.0;
63  BaseFloat last_layer_factor = 1.0;
64 
65  ParseOptions po(usage);
66  po.Register("binary", &binary_write, "Write output in binary mode");
67  po.Register("average-learning-rate", &average_learning_rate,
68  "If supplied, change learning rate geometric mean to the given "
69  "value.");
70  po.Register("first-layer-factor", &first_layer_factor, "Factor that "
71  "reduces the target relative learning rate for first layer.");
72  po.Register("last-layer-factor", &last_layer_factor, "Factor that "
73  "reduces the target relative learning rate for last layer.");
74  po.Register("retroactive", &retroactive, "If true, scale the parameter "
75  "differences as well.");
76 
77  po.Read(argc, argv);
78 
79  if (po.NumArgs() != 3) {
80  po.PrintUsage();
81  exit(1);
82  }
83 
84  KALDI_ASSERT(average_learning_rate >= 0);
85 
86  std::string prev_nnet_rxfilename = po.GetArg(1),
87  cur_nnet_rxfilename = po.GetArg(2),
88  modified_cur_nnet_rxfilename = po.GetOptArg(3);
89 
90  TransitionModel trans_model;
91  AmNnet am_prev_nnet, am_cur_nnet;
92  {
93  bool binary_read;
94  Input ki(prev_nnet_rxfilename, &binary_read);
95  trans_model.Read(ki.Stream(), binary_read);
96  am_prev_nnet.Read(ki.Stream(), binary_read);
97  }
98  {
99  bool binary_read;
100  Input ki(cur_nnet_rxfilename, &binary_read);
101  trans_model.Read(ki.Stream(), binary_read);
102  am_cur_nnet.Read(ki.Stream(), binary_read);
103  }
104 
105  if (am_prev_nnet.GetNnet().GetParameterDim() !=
106  am_cur_nnet.GetNnet().GetParameterDim()) {
107  KALDI_WARN << "Parameter-dim mismatch, cannot equalize the relative "
108  << "changes in parameters for each layer.";
109  exit(0);
110  }
111 
112  int32 ret = 0;
113 
114  // Gets relative parameter differences.
115  int32 num_updatable = am_prev_nnet.GetNnet().NumUpdatableComponents();
116  Vector<BaseFloat> relative_diff(num_updatable);
117  {
118  Nnet diff_nnet(am_prev_nnet.GetNnet());
119  diff_nnet.AddNnet(-1.0, am_cur_nnet.GetNnet());
120  diff_nnet.ComponentDotProducts(diff_nnet, &relative_diff);
121  relative_diff.ApplyPow(0.5);
122  Vector<BaseFloat> baseline_prod(num_updatable);
123  am_prev_nnet.GetNnet().ComponentDotProducts(am_prev_nnet.GetNnet(),
124  &baseline_prod);
125  baseline_prod.ApplyPow(0.5);
126  relative_diff.DivElements(baseline_prod);
127  KALDI_LOG << "Relative parameter differences per layer are "
128  << relative_diff;
129 
130  // If relative parameter difference for a certain is zero, set it to the
131  // mean of the rest values.
132  int32 num_zero = 0;
133  for (int32 i = 0; i < num_updatable; i++) {
134  if (relative_diff(i) == 0.0) {
135  num_zero++;
136  }
137  }
138  if (num_zero > 0) {
139  BaseFloat average_diff = relative_diff.Sum()
140  / static_cast<BaseFloat>(num_updatable - num_zero);
141  for (int32 i = 0; i < num_updatable; i++) {
142  if (relative_diff(i) == 0.0) {
143  relative_diff(i) = average_diff;
144  }
145  }
146  KALDI_LOG << "Zeros detected in the relative parameter difference "
147  << "vector, updating the vector to " << relative_diff;
148  }
149  }
150 
151  // Gets learning rates for previous neural net.
152  Vector<BaseFloat> prev_nnet_learning_rates(num_updatable),
153  cur_nnet_learning_rates(num_updatable);
154  am_prev_nnet.GetNnet().GetLearningRates(&prev_nnet_learning_rates);
155  am_cur_nnet.GetNnet().GetLearningRates(&cur_nnet_learning_rates);
156  KALDI_LOG << "Learning rates for previous model per layer are "
157  << prev_nnet_learning_rates;
158  KALDI_LOG << "Learning rates for current model per layer are "
159  << cur_nnet_learning_rates;
160 
161  // Gets target geometric mean.
162  BaseFloat target_geometric_mean = 0.0;
163  if (average_learning_rate == 0.0) {
164  target_geometric_mean = Exp(cur_nnet_learning_rates.SumLog()
165  / static_cast<BaseFloat>(num_updatable));
166  } else {
167  target_geometric_mean = average_learning_rate;
168  }
169  KALDI_ASSERT(target_geometric_mean > 0.0);
170 
171  // Works out the new learning rates. We start from the previous model;
172  // this ensures that if this program is run twice, we get consistent
173  // results even if it's overwritten the current model.
174  Vector<BaseFloat> nnet_learning_rates(prev_nnet_learning_rates);
175  nnet_learning_rates.DivElements(relative_diff);
176  KALDI_ASSERT(last_layer_factor > 0.0);
177  nnet_learning_rates(num_updatable - 1) *= last_layer_factor;
178  KALDI_ASSERT(first_layer_factor > 0.0);
179  nnet_learning_rates(0) *= first_layer_factor;
180  BaseFloat cur_geometric_mean = Exp(nnet_learning_rates.SumLog()
181  / static_cast<BaseFloat>(num_updatable));
182  nnet_learning_rates.Scale(target_geometric_mean / cur_geometric_mean);
183  KALDI_LOG << "New learning rates for current model per layer are "
184  << nnet_learning_rates;
185 
186  // Changes the parameter differences if --retroactivate is set to true.
187  if (retroactive) {
188  Vector<BaseFloat> scale_factors(nnet_learning_rates);
189  scale_factors.DivElements(prev_nnet_learning_rates);
190  am_cur_nnet.GetNnet().AddNnet(-1.0, am_prev_nnet.GetNnet());
191  am_cur_nnet.GetNnet().ScaleComponents(scale_factors);
192  am_cur_nnet.GetNnet().AddNnet(1.0, am_prev_nnet.GetNnet());
193  KALDI_LOG << "Scale parameter difference retroactively. Scaling factors "
194  << "are " << scale_factors;
195  }
196 
197  // Sets learning rates and writes updated model.
198  am_cur_nnet.GetNnet().SetLearningRates(nnet_learning_rates);
199 
200  SetMaxChange(0.0, &(am_cur_nnet.GetNnet()));
201 
202  Output ko(modified_cur_nnet_rxfilename, binary_write);
203  trans_model.Write(ko.Stream(), binary_write);
204  am_cur_nnet.Write(ko.Stream(), binary_write);
205 
206  return ret;
207  } catch(const std::exception &e) {
208  std::cerr << e.what() << '\n';
209  return -1;
210  }
211 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
Definition: chain.dox:20
double Exp(double x)
Definition: kaldi-math.h:83
void AddNnet(const VectorBase< BaseFloat > &scales, const Nnet &other)
For each updatable component, adds to it the corresponding element of "other" times the appropriate...
Definition: nnet-nnet.cc:576
int32 NumUpdatableComponents() const
Returns the number of updatable components.
Definition: nnet-nnet.cc:413
void ComponentDotProducts(const Nnet &other, VectorBase< BaseFloat > *dot_prod) const
Definition: nnet-nnet.cc:207
void Read(std::istream &is, bool binary)
Definition: am-nnet.cc:39
kaldi::int32 int32
virtual int32 GetParameterDim() const
Definition: nnet-nnet.cc:657
void GetLearningRates(VectorBase< BaseFloat > *learning_rates) const
Get all the learning rates in the neural net (the output must have dim equal to NumUpdatableComponent...
Definition: nnet-nnet.cc:476
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Read(std::istream &is, bool binary)
void Write(std::ostream &os, bool binary) const
Definition: am-nnet.cc:31
#define KALDI_WARN
Definition: kaldi-error.h:150
void SetMaxChange(BaseFloat max_change, Nnet *nnet)
void ScaleComponents(const VectorBase< BaseFloat > &scales)
Scales the parameters of each of the updatable components.
Definition: nnet-nnet.cc:421
void Write(std::ostream &os, bool binary) const
A class representing a vector.
Definition: kaldi-vector.h:406
void SetLearningRates(BaseFloat learning_rates)
Set all the learning rates in the neural net to this value.
Definition: nnet-nnet.cc:346
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
#define KALDI_LOG
Definition: kaldi-error.h:153
const Nnet & GetNnet() const
Definition: am-nnet.h:61