nnet-set-learnrate.cc File Reference
Include dependency graph for nnet-set-learnrate.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 27 of file nnet-set-learnrate.cc.

References ParseOptions::GetArg(), Nnet::GetComponent(), rnnlm::i, Component::IsUpdatable(), ParseOptions::NumArgs(), Nnet::NumComponents(), ParseOptions::PrintUsage(), ParseOptions::Read(), Nnet::Read(), ParseOptions::Register(), UpdatableComponent::SetBiasLearnRateCoef(), UpdatableComponent::SetLearnRateCoef(), kaldi::SplitStringToIntegers(), and Nnet::Write().

27  {
28  try {
29  using namespace kaldi;
30  using namespace kaldi::nnet1;
31  typedef kaldi::int32 int32;
32 
33  const char *usage =
34  "Sets learning rate coefficient inside of 'nnet1' model\n"
35  "Usage: nnet-set-learnrate --components=<csl> --coef=<float> <nnet-in> <nnet-out>\n"
36  "e.g.: nnet-set-learnrate --components=1:3:5 --coef=0.5 --bias-coef=0.1 nnet-in nnet-out\n";
37 
38  ParseOptions po(usage);
39  bool binary = true;
40  po.Register("binary", &binary, "Write output in binary mode");
41 
42  std::string components_str = "";
43  po.Register("components", &components_str,
44  "Select components by 'csl' of 1..N values. Layout is the same as in "
45  "'nnet-info' output, (example 1:3:5)");
46 
47  float coef = 1.0,
48  weight_coef = 1.0,
49  bias_coef = 1.0;
50 
51  po.Register("coef", &coef,
52  "Learn-rate coefficient for both weight matrices and biases.");
53  po.Register("weight-coef", &weight_coef,
54  "Learn-rate coefficient for weight matrices "
55  "(used as: coef * weight_coef).");
56  po.Register("bias-coef", &bias_coef,
57  "Learn-rate coefficient for bias (used as: coef * bias_coef).");
58 
59  po.Read(argc, argv);
60 
61  if (po.NumArgs() != 2) {
62  po.PrintUsage();
63  exit(1);
64  }
65 
66  std::string nnet_in_filename = po.GetArg(1),
67  nnet_out_filename = po.GetArg(2);
68 
69  Nnet nnet;
70  nnet.Read(nnet_in_filename);
71 
72  // A vector which contains indices of components,
73  // where we will set the 'learn-rate coefficients',
74  std::vector<int32> components;
75  if (components_str != "") {
76  // components were selected by the option,
77  kaldi::SplitStringToIntegers(components_str, ":", false, &components);
78  } else {
79  // otherwise select all the components (1..Ncomp),
80  for (int32 i = 1; i <= nnet.NumComponents(); i++) {
81  components.push_back(i);
82  }
83  }
84 
85  // Setting the learning rate coefficients,
86  for (int32 i = 0; i < components.size(); i++) {
87  if (nnet.GetComponent(components[i]-1).IsUpdatable()) {
88  UpdatableComponent& comp =
89  dynamic_cast<UpdatableComponent&>(nnet.GetComponent(components[i]-1));
90  comp.SetLearnRateCoef(coef * weight_coef); // weight matrices, etc.,
91  comp.SetBiasLearnRateCoef(coef * bias_coef); // biases,
92  }
93  }
94 
95  // Write the 'nnet1' network,
96  nnet.Write(nnet_out_filename, binary);
97 
98  return 0;
99  } catch(const std::exception &e) {
100  std::cerr << e.what();
101  return -1;
102  }
103 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
int32 NumComponents() const
Returns the number of 'Components' which form the NN.
Definition: nnet-nnet.h:66
bool SplitStringToIntegers(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< I > *out)
Split a string (e.g. 1:2:3) into a vector of integers; returns false if parsing fails.
Definition: text-utils.h:68
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename'.
Definition: nnet-nnet.cc:367
Class UpdatableComponent is a Component which has trainable parameters, it contains SGD training hype...
kaldi::int32 int32
virtual void SetLearnRateCoef(BaseFloat val)
Set the learn-rate coefficient.
virtual bool IsUpdatable() const
Check if the component has the 'Updatable' interface (trainable components).
virtual void SetBiasLearnRateCoef(BaseFloat val)
Set the learn-rate coefficient for the bias.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename'.
Definition: nnet-nnet.cc:333
const Component & GetComponent(int32 c) const
Component accessor.
Definition: nnet-nnet.cc:153