nnet-set-learnrate.cc
Go to the documentation of this file.
1 // nnetbin/nnet-set-learnrate.cc
2 
3 // Copyright 2016, Brno University of Technology
4 // (author: Katerina Zmolikova, Karel Vesely)
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include "util/common-utils.h"
22 #include "nnet/nnet-nnet.h"
23 #include "nnet/nnet-component.h"
25 #include "nnet/nnet-activation.h"
26 
27 int main(int argc, char *argv[]) {
28  try {
29  using namespace kaldi;
30  using namespace kaldi::nnet1;
31  typedef kaldi::int32 int32;
32 
33  const char *usage =
34  "Sets learning rate coefficient inside of 'nnet1' model\n"
35  "Usage: nnet-set-learnrate --components=<csl> --coef=<float> <nnet-in> <nnet-out>\n"
36  "e.g.: nnet-set-learnrate --components=1:3:5 --coef=0.5 --bias-coef=0.1 nnet-in nnet-out\n";
37 
38  ParseOptions po(usage);
39  bool binary = true;
40  po.Register("binary", &binary, "Write output in binary mode");
41 
42  std::string components_str = "";
43  po.Register("components", &components_str,
44  "Select components by 'csl' of 1..N values. Layout is the same as in "
45  "'nnet-info' output, (example 1:3:5)");
46 
47  float coef = 1.0,
48  weight_coef = 1.0,
49  bias_coef = 1.0;
50 
51  po.Register("coef", &coef,
52  "Learn-rate coefficient for both weight matrices and biases.");
53  po.Register("weight-coef", &weight_coef,
54  "Learn-rate coefficient for weight matrices "
55  "(used as: coef * weight_coef).");
56  po.Register("bias-coef", &bias_coef,
57  "Learn-rate coefficient for bias (used as: coef * bias_coef).");
58 
59  po.Read(argc, argv);
60 
61  if (po.NumArgs() != 2) {
62  po.PrintUsage();
63  exit(1);
64  }
65 
66  std::string nnet_in_filename = po.GetArg(1),
67  nnet_out_filename = po.GetArg(2);
68 
69  Nnet nnet;
70  nnet.Read(nnet_in_filename);
71 
72  // A vector which contains indices of components,
73  // where we will set the 'learn-rate coefficients',
74  std::vector<int32> components;
75  if (components_str != "") {
76  // components were selected by the option,
77  kaldi::SplitStringToIntegers(components_str, ":", false, &components);
78  } else {
79  // otherwise select all the components (1..Ncomp),
80  for (int32 i = 1; i <= nnet.NumComponents(); i++) {
81  components.push_back(i);
82  }
83  }
84 
85  // Setting the learning rate coefficients,
86  for (int32 i = 0; i < components.size(); i++) {
87  if (nnet.GetComponent(components[i]-1).IsUpdatable()) {
88  UpdatableComponent& comp =
89  dynamic_cast<UpdatableComponent&>(nnet.GetComponent(components[i]-1));
90  comp.SetLearnRateCoef(coef * weight_coef); // weight matrices, etc.,
91  comp.SetBiasLearnRateCoef(coef * bias_coef); // biases,
92  }
93  }
94 
95  // Write the 'nnet1' network,
96  nnet.Write(nnet_out_filename, binary);
97 
98  return 0;
99  } catch(const std::exception &e) {
100  std::cerr << e.what();
101  return -1;
102  }
103 }
104 
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
int32 NumComponents() const
Returns the number of 'Components' which form the NN.
Definition: nnet-nnet.h:66
bool SplitStringToIntegers(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< I > *out)
Split a string (e.g.
Definition: text-utils.h:68
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename'.
Definition: nnet-nnet.cc:367
Class UpdatableComponent is a Component which has trainable parameters, it contains SGD training hype...
kaldi::int32 int32
virtual void SetLearnRateCoef(BaseFloat val)
Set the learn-rate coefficient.
void Register(const std::string &name, bool *ptr, const std::string &doc)
int main(int argc, char *argv[])
virtual bool IsUpdatable() const
Check if component has 'Updatable' interface (trainable components).
virtual void SetBiasLearnRateCoef(BaseFloat val)
Set the learn-rate coefficient for bias.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename'.
Definition: nnet-nnet.cc:333
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
int NumArgs() const
Number of positional parameters (c.f. argc-1).
const Component & GetComponent(int32 c) const
Component accessor.
Definition: nnet-nnet.cc:153