nnet-trnopts.h
// nnet/nnet-trnopts.h

// Copyright 2013 Brno University of Technology (Author: Karel Vesely)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_NNET_NNET_TRNOPTS_H_
#define KALDI_NNET_NNET_TRNOPTS_H_

#include "base/kaldi-common.h"
#include "itf/options-itf.h"

namespace kaldi {
namespace nnet1 {

struct NnetTrainOptions {
  // option declaration
  BaseFloat learn_rate;
  BaseFloat momentum;
  BaseFloat l2_penalty;
  BaseFloat l1_penalty;

  // default values
  NnetTrainOptions():
    learn_rate(0.008),
    momentum(0.0),
    l2_penalty(0.0),
    l1_penalty(0.0)
  { }

  // register options
  void Register(OptionsItf *opts) {
    opts->Register("learn-rate", &learn_rate, "Learning rate");
    opts->Register("momentum", &momentum, "Momentum");
    opts->Register("l2-penalty", &l2_penalty, "L2 penalty (weight decay)");
    opts->Register("l1-penalty", &l1_penalty, "L1 penalty (promote sparsity)");
  }

  // print for debug purposes
  friend std::ostream& operator<<(std::ostream& os, const NnetTrainOptions& opts) {
    os << "NnetTrainOptions : "
       << "learn_rate " << opts.learn_rate << ", "
       << "momentum " << opts.momentum << ", "
       << "l2_penalty " << opts.l2_penalty << ", "
       << "l1_penalty " << opts.l1_penalty;
    return os;
  }
};

struct RbmTrainOptions {
  // option declaration
  BaseFloat learn_rate;
  BaseFloat momentum;
  BaseFloat momentum_max;
  int32 momentum_steps;
  int32 momentum_step_period;
  BaseFloat l2_penalty;

  // default values
  RbmTrainOptions():
    learn_rate(0.4),
    momentum(0.5),
    momentum_max(0.9),
    momentum_steps(40),
    momentum_step_period(500000),
      // 500000 frames/step * 40 steps = 20M frames, i.e. ~55 hours of data
      // (at 100 frames/sec) over which the momentum rises linearly
    l2_penalty(0.0002)
  { }

  // register options
  void Register(OptionsItf *opts) {
    opts->Register("learn-rate", &learn_rate, "Learning rate");

    opts->Register("momentum", &momentum,
                   "Initial momentum for linear scheduling");
    opts->Register("momentum-max", &momentum_max,
                   "Final momentum for linear scheduling");
    opts->Register("momentum-steps", &momentum_steps,
                   "Number of steps of linear momentum scheduling");
    opts->Register("momentum-step-period", &momentum_step_period,
                   "Number of datapoints per single momentum increase step");

    opts->Register("l2-penalty", &l2_penalty,
                   "L2 penalty (weight decay, increases mixing-rate)");
  }

  // print for debug purposes
  friend std::ostream& operator<<(std::ostream& os, const RbmTrainOptions& opts) {
    os << "RbmTrainOptions : "
       << "learn_rate " << opts.learn_rate << ", "
       << "momentum " << opts.momentum << ", "
       << "momentum_max " << opts.momentum_max << ", "
       << "momentum_steps " << opts.momentum_steps << ", "
       << "momentum_step_period " << opts.momentum_step_period << ", "
       << "l2_penalty " << opts.l2_penalty;
    return os;
  }
};  // struct RbmTrainOptions

}  // namespace nnet1
}  // namespace kaldi

#endif  // KALDI_NNET_NNET_TRNOPTS_H_
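Both structs follow the same pattern: each field is exposed as a command-line flag once Register() is called on an object implementing OptionsItf. Below is a minimal, hypothetical sketch (not part of this header) of how such options are typically parsed in a Kaldi-style program via ParseOptions; the program name and usage string are illustrative only.

// Illustrative sketch only (assumed example, not an actual Kaldi binary):
// register NnetTrainOptions with ParseOptions, which implements OptionsItf,
// so that --learn-rate, --momentum, --l2-penalty, --l1-penalty become flags.
#include <iostream>

#include "nnet/nnet-trnopts.h"
#include "util/parse-options.h"

int main(int argc, char *argv[]) {
  using namespace kaldi;
  using namespace kaldi::nnet1;

  const char *usage =
      "Hypothetical example of parsing nnet1 training options.\n"
      "Usage: nnet-trnopts-example [options]\n";
  ParseOptions po(usage);

  // SGD training options (NnetTrainOptions); a binary doing RBM pre-training
  // would register an RbmTrainOptions object instead, since the two structs
  // reuse option names such as --learn-rate and --l2-penalty.
  NnetTrainOptions trn_opts;
  trn_opts.Register(&po);

  po.Read(argc, argv);  // e.g. --learn-rate=0.002 --momentum=0.9

  // The friend operator<< defined above gives a one-line debug dump.
  std::cout << trn_opts << std::endl;

  // With the RbmTrainOptions defaults, momentum would rise linearly from 0.5
  // to 0.9 in 40 steps, one step per 500000 frames: 20M frames, ~55 h at 100 fps.
  return 0;
}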