43 using namespace kaldi;
46 typedef kaldi::int64 int64;
49 "This program modifies the learning rates so as to equalize the\n" 50 "relative changes in parameters for each layer, while keeping their\n" 51 "geometric mean the same (or changing it to a value specified using\n" 52 "the --average-learning-rate option).\n" 54 "Usage: nnet-modify-learning-rates [options] <prev-model> \\\n" 55 " <cur-model> <modified-cur-model>\n" 56 "e.g.: nnet-modify-learning-rates --average-learning-rate=0.0002 \\\n" 57 " 5.mdl 6.mdl 6.mdl\n";
59 bool binary_write =
true;
60 bool retroactive =
false;
66 po.Register(
"binary", &binary_write,
"Write output in binary mode");
67 po.Register(
"average-learning-rate", &average_learning_rate,
68 "If supplied, change learning rate geometric mean to the given " 70 po.Register(
"first-layer-factor", &first_layer_factor,
"Factor that " 71 "reduces the target relative learning rate for first layer.");
72 po.Register(
"last-layer-factor", &last_layer_factor,
"Factor that " 73 "reduces the target relative learning rate for last layer.");
74 po.Register(
"retroactive", &retroactive,
"If true, scale the parameter " 75 "differences as well.");
79 if (po.NumArgs() != 3) {
86 std::string prev_nnet_rxfilename = po.GetArg(1),
87 cur_nnet_rxfilename = po.GetArg(2),
88 modified_cur_nnet_rxfilename = po.GetOptArg(3);
91 AmNnet am_prev_nnet, am_cur_nnet;
94 Input ki(prev_nnet_rxfilename, &binary_read);
95 trans_model.
Read(ki.Stream(), binary_read);
96 am_prev_nnet.
Read(ki.Stream(), binary_read);
100 Input ki(cur_nnet_rxfilename, &binary_read);
101 trans_model.
Read(ki.Stream(), binary_read);
102 am_cur_nnet.
Read(ki.Stream(), binary_read);
107 KALDI_WARN <<
"Parameter-dim mismatch, cannot equalize the relative " 108 <<
"changes in parameters for each layer.";
120 diff_nnet.ComponentDotProducts(diff_nnet, &relative_diff);
121 relative_diff.ApplyPow(0.5);
125 baseline_prod.ApplyPow(0.5);
126 relative_diff.DivElements(baseline_prod);
127 KALDI_LOG <<
"Relative parameter differences per layer are " 133 for (int32
i = 0;
i < num_updatable;
i++) {
134 if (relative_diff(
i) == 0.0) {
139 BaseFloat average_diff = relative_diff.Sum()
140 /
static_cast<BaseFloat>(num_updatable - num_zero);
141 for (int32
i = 0;
i < num_updatable;
i++) {
142 if (relative_diff(
i) == 0.0) {
143 relative_diff(
i) = average_diff;
146 KALDI_LOG <<
"Zeros detected in the relative parameter difference " 147 <<
"vector, updating the vector to " << relative_diff;
153 cur_nnet_learning_rates(num_updatable);
156 KALDI_LOG <<
"Learning rates for previous model per layer are " 157 << prev_nnet_learning_rates;
158 KALDI_LOG <<
"Learning rates for current model per layer are " 159 << cur_nnet_learning_rates;
163 if (average_learning_rate == 0.0) {
164 target_geometric_mean =
Exp(cur_nnet_learning_rates.SumLog()
165 /
static_cast<BaseFloat>(num_updatable));
167 target_geometric_mean = average_learning_rate;
175 nnet_learning_rates.DivElements(relative_diff);
177 nnet_learning_rates(num_updatable - 1) *= last_layer_factor;
179 nnet_learning_rates(0) *= first_layer_factor;
180 BaseFloat cur_geometric_mean =
Exp(nnet_learning_rates.SumLog()
181 /
static_cast<BaseFloat>(num_updatable));
182 nnet_learning_rates.Scale(target_geometric_mean / cur_geometric_mean);
183 KALDI_LOG <<
"New learning rates for current model per layer are " 184 << nnet_learning_rates;
189 scale_factors.DivElements(prev_nnet_learning_rates);
193 KALDI_LOG <<
"Scale parameter difference retroactively. Scaling factors " 194 <<
"are " << scale_factors;
202 Output ko(modified_cur_nnet_rxfilename, binary_write);
203 trans_model.
Write(ko.Stream(), binary_write);
204 am_cur_nnet.
Write(ko.Stream(), binary_write);
207 }
catch(
const std::exception &e) {
208 std::cerr << e.what() <<
'\n';
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void AddNnet(const VectorBase< BaseFloat > &scales, const Nnet &other)
For each updatable component, adds to it the corresponding element of "other" times the appropriate...
int32 NumUpdatableComponents() const
Returns the number of updatable components.
void ComponentDotProducts(const Nnet &other, VectorBase< BaseFloat > *dot_prod) const
void Read(std::istream &is, bool binary)
virtual int32 GetParameterDim() const
void GetLearningRates(VectorBase< BaseFloat > *learning_rates) const
Get all the learning rates in the neural net (the output must have dim equal to NumUpdatableComponent...
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
void Read(std::istream &is, bool binary)
void Write(std::ostream &os, bool binary) const
void SetMaxChange(BaseFloat max_change, Nnet *nnet)
void ScaleComponents(const VectorBase< BaseFloat > &scales)
Scales the parameters of each of the updatable components.
void Write(std::ostream &os, bool binary) const
A class representing a vector.
void SetLearningRates(BaseFloat learning_rates)
Set all the learning rates in the neural net to this value.
#define KALDI_ASSERT(cond)
const Nnet & GetNnet() const