// nnet-modify-learning-rates — Kaldi nnet2 command-line tool entry point.
//
// NOTE(review): this chunk is a damaged extraction of the tool's main():
// the original file's own line numbers (41, 43, 46, ...) have been fused
// into the statements, lines are split mid-token, and whole statements are
// missing (e.g. the `try {`, the `ParseOptions po(usage);` declaration,
// the declarations of average_learning_rate / first_layer_factor /
// last_layer_factor, the model Read() calls, and the closing braces).
// The code is therefore left byte-identical; the comments below describe
// only what the visible fragments establish and hedge everything else.
//
// Purpose (from the usage string below): modify per-layer learning rates so
// that the relative parameter change is equalized across layers, while
// keeping the geometric mean of the learning rates unchanged (or setting it
// to the value given by --average-learning-rate).
41 int main(
int argc,
char *argv[]) {
43 using namespace kaldi;
46 typedef kaldi::int64 int64;
// Usage/help text for the tool (original source lines 49-57, concatenated
// string literals with the original line numbers fused in by extraction).
49 "This program modifies the learning rates so as to equalize the\n" 50 "relative changes in parameters for each layer, while keeping their\n" 51 "geometric mean the same (or changing it to a value specified using\n" 52 "the --average-learning-rate option).\n" 54 "Usage: nnet-modify-learning-rates [options] <prev-model> \\\n" 55 " <cur-model> <modified-cur-model>\n" 56 "e.g.: nnet-modify-learning-rates --average-learning-rate=0.0002 \\\n" 57 " 5.mdl 6.mdl 6.mdl\n";
// Command-line option defaults.
59 bool binary_write =
true;
60 bool retroactive =
false;
// Option registration on a ParseOptions object `po` (its declaration is
// missing from this extraction).  The registered pointers
// &average_learning_rate, &first_layer_factor, &last_layer_factor are also
// declared in lines the extraction dropped.
66 po.
Register(
"binary", &binary_write,
"Write output in binary mode");
67 po.
Register(
"average-learning-rate", &average_learning_rate,
68 "If supplied, change learning rate geometric mean to the given " 70 po.
Register(
"first-layer-factor", &first_layer_factor,
"Factor that " 71 "reduces the target relative learning rate for first layer.");
72 po.
Register(
"last-layer-factor", &last_layer_factor,
"Factor that " 73 "reduces the target relative learning rate for last layer.");
74 po.
Register(
"retroactive", &retroactive,
"If true, scale the parameter " 75 "differences as well.");
// Positional arguments: <prev-model> <cur-model> and an optional third
// argument <modified-cur-model> (read with GetOptArg, so it may be empty).
86 std::string prev_nnet_rxfilename = po.
GetArg(1),
87 cur_nnet_rxfilename = po.
GetArg(2),
88 modified_cur_nnet_rxfilename = po.
GetOptArg(3);
// Open the previous and current acoustic models.  The Read() calls that
// consume these Input streams are missing from this extraction — presumably
// they fill am_prev_nnet / am_cur_nnet (TODO confirm against upstream).
91 AmNnet am_prev_nnet, am_cur_nnet;
94 Input ki(prev_nnet_rxfilename, &binary_read);
100 Input ki(cur_nnet_rxfilename, &binary_read);
// If the two models' parameter dims differ, only warn: equalization of
// per-layer relative changes is then impossible.
107 KALDI_WARN <<
"Parameter-dim mismatch, cannot equalize the relative " 108 <<
"changes in parameters for each layer.";
// Per-updatable-component dot product of the parameter difference with
// itself: relative_diff(i) measures how much layer i's parameters changed.
120 diff_nnet.ComponentDotProducts(diff_nnet, &relative_diff);
127 KALDI_LOG <<
// First pass over the updatable components: the visible body only tests
// relative_diff(i) == 0.0 — presumably it counts the zero entries into
// num_zero, which is used below (TODO confirm; the loop body was dropped).
"Relative parameter differences per layer are " 133 for (int32
i = 0;
i < num_updatable;
i++) {
134 if (relative_diff(
i) == 0.0) {
// Replace each zero entry of relative_diff with the mean of the non-zero
// entries, so later ratios/geometric means stay well-defined.
139 BaseFloat average_diff = relative_diff.Sum()
140 /
static_cast<BaseFloat>(num_updatable - num_zero);
141 for (int32
i = 0;
i < num_updatable;
i++) {
142 if (relative_diff(
i) == 0.0) {
143 relative_diff(
i) = average_diff;
146 KALDI_LOG <<
"Zeros detected in the relative parameter difference " 147 <<
"vector, updating the vector to " << relative_diff;
// Per-layer learning-rate vectors for the previous and current models
// (the prev_nnet_learning_rates declaration was dropped by extraction).
153 cur_nnet_learning_rates(num_updatable);
156 KALDI_LOG <<
"Learning rates for previous model per layer are " 157 << prev_nnet_learning_rates;
158 KALDI_LOG <<
"Learning rates for current model per layer are " 159 << cur_nnet_learning_rates;
// Target geometric mean of the new learning rates: the current model's
// geometric mean (exp of the mean log), unless --average-learning-rate
// supplied a non-zero override.
163 if (average_learning_rate == 0.0) {
164 target_geometric_mean =
Exp(cur_nnet_learning_rates.SumLog()
165 /
static_cast<BaseFloat>(num_updatable));
167 target_geometric_mean = average_learning_rate;
// Damp the first and last layers by their configured factors, then rescale
// the whole vector so its geometric mean equals target_geometric_mean
// (cur_geometric_mean's computation is partially missing here).
177 nnet_learning_rates(num_updatable - 1) *= last_layer_factor;
179 nnet_learning_rates(0) *= first_layer_factor;
181 /
static_cast<BaseFloat>(num_updatable));
182 nnet_learning_rates.
Scale(target_geometric_mean / cur_geometric_mean);
183 KALDI_LOG <<
"New learning rates for current model per layer are " 184 << nnet_learning_rates;
// Retroactive mode: element-wise ratio new-rate / previous-rate, used to
// scale the already-accumulated parameter differences.
189 scale_factors.
DivElements(prev_nnet_learning_rates);
193 KALDI_LOG <<
"Scale parameter difference retroactively. Scaling factors " 194 <<
"are " << scale_factors;
// Write the modified current model (binary per --binary).  NOTE(review):
// despite the `rxfilename` suffix, this name is used as an output (wxfilename).
202 Output ko(modified_cur_nnet_rxfilename, binary_write);
207 }
// Top-level handler: report any exception to stderr.  The closing braces
// and the return statements were dropped by the extraction.
catch(
const std::exception &e) {
208 std::cerr << e.what() <<
'\n';
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
const Component & GetComponent(int32 c) const
void AddNnet(const VectorBase< BaseFloat > &scales, const Nnet &other)
For each updatable component, adds to it the corresponding element of "other" times the appropriate...
int32 NumUpdatableComponents() const
Returns the number of updatable components.
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
Abstract class, basic element of the network, it is a box with defined inputs, outputs, and tranformation functions interface.
void ComponentDotProducts(const Nnet &other, VectorBase< BaseFloat > *dot_prod) const
Real SumLog() const
Returns sum of the logs of the elements.
void Read(std::istream &is, bool binary)
int main(int argc, char *argv[])
virtual int32 GetParameterDim() const
void Register(const std::string &name, bool *ptr, const std::string &doc)
int32 NumComponents() const
Returns number of components — think of this as similar to # of layers, but e.g.
void SetMaxChange(BaseFloat max_change)
void GetLearningRates(VectorBase< BaseFloat > *learning_rates) const
Get all the learning rates in the neural net (the output must have dim equal to NumUpdatableComponent...
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
void Read(std::istream &is, bool binary)
void Write(std::ostream &os, bool binary) const
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
void Scale(Real alpha)
Multiplies all elements by this constant.
void SetMaxChange(BaseFloat max_change, Nnet *nnet)
void ScaleComponents(const VectorBase< BaseFloat > &scales)
Scales the parameters of each of the updatable components.
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void Write(std::ostream &os, bool binary) const
A class representing a vector.
void SetLearningRates(BaseFloat learning_rates)
Set all the learning rates in the neural net to this value.
#define KALDI_ASSERT(cond)
void ApplyPow(Real power)
Take all elements of vector to a power.
void DivElements(const VectorBase< Real > &v)
Divide element-by-element by a vector.
const Nnet & GetNnet() const
std::string GetOptArg(int param) const