estimate-am-sgmm2-ebw.h
Go to the documentation of this file.
1 // sgmm2/estimate-am-sgmm2-ebw.h
2 
3 // Copyright 2012 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #ifndef KALDI_SGMM2_ESTIMATE_AM_SGMM2_EBW_H_
21 #define KALDI_SGMM2_ESTIMATE_AM_SGMM2_EBW_H_ 1
22 
23 #include <string>
24 #include <vector>
25 
26 #include "gmm/model-common.h"
27 #include "itf/options-itf.h"
29 
30 namespace kaldi {
31 
70 
72 
78 
81  tau_v = 50.0;
82  lrate_v = 0.5;
83  tau_M = 500.0;
84  lrate_M = 0.5;
85  tau_N = 500.0;
86  lrate_N = 0.5;
87  tau_c = 10.0;
88  tau_w = 50.0;
89  lrate_w = 1.0;
90  tau_u = 50.0;
91  lrate_u = 1.0;
92  max_impr_u = 0.25;
93  tau_Sigma = 500.0;
94  lrate_Sigma = 0.5;
95 
96  min_substate_weight = 1.0e-05;
97  cov_min_value = 0.5;
98 
99  max_cond = 1.0e+05;
100  epsilon = 1.0e-40;
101  }
102 
103  void Register(OptionsItf *opts) {
104  std::string module = "EbwAmSgmm2Options: ";
105  opts->Register("tau-v", &tau_v, module+
106  "Smoothing constant for phone vector estimation.");
107  opts->Register("lrate-v", &lrate_v, module+
108  "Learning rate constant for phone vector estimation.");
109  opts->Register("tau-m", &tau_M, module+
110  "Smoothing constant for estimation of phonetic-subspace projections (M).");
111  opts->Register("lrate-m", &lrate_M, module+
112  "Learning rate constant for phonetic-subspace projections.");
113  opts->Register("tau-n", &tau_N, module+
114  "Smoothing constant for estimation of speaker-subspace projections (N).");
115  opts->Register("lrate-n", &lrate_N, module+
116  "Learning rate constant for speaker-subspace projections.");
117  opts->Register("tau-c", &tau_c, module+
118  "Smoothing constant for estimation of substate weights (c)");
119  opts->Register("tau-w", &tau_w, module+
120  "Smoothing constant for estimation of phonetic-space weight projections (w)");
121  opts->Register("lrate-w", &lrate_w, module+
122  "Learning rate constant for phonetic-space weight-projections (w)");
123  opts->Register("tau-u", &tau_u, module+
124  "Smoothing constant for estimation of speaker-space weight projections (u)");
125  opts->Register("lrate-u", &lrate_u, module+
126  "Learning rate constant for speaker-space weight-projections (u)");
127  opts->Register("tau-sigma", &tau_Sigma, module+
128  "Smoothing constant for estimation of within-class covariances (Sigma)");
129  opts->Register("lrate-sigma", &lrate_Sigma, module+
130  "Constant that controls speed of learning for variances (larger->slower)");
131  opts->Register("cov-min-value", &cov_min_value, module+
132  "Minimum value that an eigenvalue of the updated covariance matrix can take, "
133  "relative to its old value (maximum is inverse of this.)");
134  opts->Register("min-substate-weight", &min_substate_weight, module+
135  "Floor for weights of sub-states.");
136  opts->Register("max-cond", &max_cond, module+
137  "Value used in handling singular matrices during update.");
138  opts->Register("epsilon", &max_cond, module+
139  "Value used in handling singular matrices during update.");
140  }
141 };
142 
143 
148  public:
// Constructs an updater from an options object: the member options_ is
// initialized from `options`.  `explicit` prevents an EbwAmSgmm2Options from
// being implicitly converted into an EbwAmSgmm2Updater.
149  explicit EbwAmSgmm2Updater(const EbwAmSgmm2Options &options):
150  options_(options) {}
151 
// Public update entry point.  Takes two read-only accumulator objects, a
// non-const pointer to the model, a set of update flags, and two BaseFloat
// output pointers.
// NOTE(review): num_accs / den_accs are presumably the numerator and
// denominator statistics of the EBW update; `flags` presumably selects which
// parameter types get updated; auxf_change_out / count_out presumably
// receive the objective-function change and the associated count -- confirm
// against the .cc implementation.
152  void Update(const MleAmSgmm2Accs &num_accs,
153  const MleAmSgmm2Accs &den_accs,
154  AmSgmm2 *model,
155  SgmmUpdateFlagsType flags,
156  BaseFloat *auxf_change_out,
157  BaseFloat *count_out);
158 
159  protected:
160  // The following two classes relate to multi-core parallelization of some
161  // phases of the update.
162  friend class EbwUpdateWClass;
164  private:
166 
168 
// Phone-vector stage of the update.  Declared const, so it does not modify
// the updater object itself; the model is passed as a non-const pointer.
// NOTE(review): the returned double is presumably the objective-function
// improvement from this stage, and H presumably holds one symmetric matrix
// per Gaussian -- confirm against the .cc implementation.
169  double UpdatePhoneVectors(const MleAmSgmm2Accs &num_accs,
170  const MleAmSgmm2Accs &den_accs,
171  const std::vector< SpMatrix<double> > &H,
172  AmSgmm2 *model) const;
173 
174  // Called from UpdatePhoneVectors; updates a subset of states
175  // (relates to multi-threading).
// NOTE(review): num_threads / thread_id presumably determine which subset of
// states this call handles, and *auxf_impr presumably receives the
// objective-function improvement for that subset -- confirm against the .cc
// implementation.
176  void UpdatePhoneVectorsInternal(const MleAmSgmm2Accs &num_accs,
177  const MleAmSgmm2Accs &den_accs,
178  const std::vector<SpMatrix<double> > &H,
179  AmSgmm2 *model,
180  double *auxf_impr,
181  int32 num_threads,
182  int32 thread_id) const;
183  // Called from UpdatePhoneVectorsInternal
// Static helper: it needs no updater instance.  The model and accumulators
// are read-only here; results are written through the g_jm and H_jm pointers.
// NOTE(review): j1 and m presumably index a state and a sub-state within it,
// w_jm / gamma_jm presumably are that sub-state's weight vector and count,
// and g_jm / H_jm presumably receive the linear and quadratic statistics --
// confirm against the .cc implementation.
184  static void ComputePhoneVecStats(const MleAmSgmm2Accs &accs,
185  const AmSgmm2 &model,
186  const std::vector<SpMatrix<double> > &H,
187  int32 j1,
188  int32 m,
189  const Vector<double> &w_jm,
190  double gamma_jm,
191  Vector<double> *g_jm,
192  SpMatrix<double> *H_jm);
193 
// Update stage for the M quantities.  Declared const; takes separate
// Q_num / Q_den matrix sets and gamma_num / gamma_den vectors alongside the
// two accumulator objects, and the model as a non-const pointer.
// NOTE(review): M presumably refers to the phonetic-subspace projections
// (the quantity tau_M / lrate_M apply to), and the returned double is
// presumably the objective-function improvement -- confirm against the .cc
// implementation.
194  double UpdateM(const MleAmSgmm2Accs &num_accs,
195  const MleAmSgmm2Accs &den_accs,
196  const std::vector< SpMatrix<double> > &Q_num,
197  const std::vector< SpMatrix<double> > &Q_den,
198  const Vector<double> &gamma_num,
199  const Vector<double> &gamma_den,
200  AmSgmm2 *model) const;
201 
// Update stage for the N quantities.  Declared const; the model is passed as
// a non-const pointer.
// NOTE(review): N presumably refers to the speaker-subspace projections (the
// quantity tau_N / lrate_N apply to), and the returned double is presumably
// the objective-function improvement -- confirm against the .cc
// implementation.
202  double UpdateN(const MleAmSgmm2Accs &num_accs,
203  const MleAmSgmm2Accs &den_accs,
204  const Vector<double> &gamma_num,
205  const Vector<double> &gamma_den,
206  AmSgmm2 *model) const;
207 
// Update stage for the variances.  Declared const; in addition to the
// accumulators and gamma vectors it takes a set of symmetric matrices
// S_means, and the model as a non-const pointer.
// NOTE(review): presumably updates the within-class covariances Sigma (the
// quantity tau_Sigma / lrate_Sigma / cov_min_value apply to) and returns the
// objective-function improvement -- confirm against the .cc implementation.
208  double UpdateVars(const MleAmSgmm2Accs &num_accs,
209  const MleAmSgmm2Accs &den_accs,
210  const Vector<double> &gamma_num,
211  const Vector<double> &gamma_den,
212  const std::vector< SpMatrix<double> > &S_means,
213  AmSgmm2 *model) const;
214 
// Update stage for the w quantities.  Unlike the phone-vector, M, N and
// variance stages, this member is not declared const.
// NOTE(review): w presumably refers to the phonetic-space weight projections
// (the quantity tau_w / lrate_w apply to); the friend class EbwUpdateWClass
// presumably parallelizes part of this stage; the returned double is
// presumably the objective-function improvement -- confirm against the .cc
// implementation.
217  double UpdateW(const MleAmSgmm2Accs &num_accs,
218  const MleAmSgmm2Accs &den_accs,
219  const Vector<double> &gamma_num,
220  const Vector<double> &gamma_den,
221  AmSgmm2 *model);
222 
223 
// Update stage for the u quantities.  Same parameter list as UpdateW, and
// likewise not declared const.
// NOTE(review): u presumably refers to the speaker-space weight projections
// (the quantity tau_u / lrate_u / max_impr_u apply to), and the returned
// double is presumably the objective-function improvement -- confirm against
// the .cc implementation.
224  double UpdateU(const MleAmSgmm2Accs &num_accs,
225  const MleAmSgmm2Accs &den_accs,
226  const Vector<double> &gamma_num,
227  const Vector<double> &gamma_den,
228  AmSgmm2 *model);
229 
// Update stage for the sub-state weights.  Takes only the two accumulator
// objects and the model (non-const pointer); not declared const.
// NOTE(review): presumably governed by tau_c and min_substate_weight, and
// presumably returns the objective-function improvement -- confirm against
// the .cc implementation.
230  double UpdateSubstateWeights(const MleAmSgmm2Accs &num_accs,
231  const MleAmSgmm2Accs &den_accs,
232  AmSgmm2 *model);
233 
// Default constructor with an empty body.  It follows the `private:` label
// (no later access specifier is visible in this listing), so outside code
// has to go through the explicit constructor that takes options.
235  EbwAmSgmm2Updater() {} // Prevent unconfigured updater.
236 };
237 
238 
239 } // namespace kaldi
240 
241 
242 #endif // KALDI_SGMM2_ESTIMATE_AM_SGMM2_EBW_H_
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
Class for definition of the subspace Gmm acoustic model.
Definition: am-sgmm2.h:231
This header implements a form of Extended Baum-Welch training for SGMMs.
BaseFloat tau_Sigma
Tau value for smoothing covariance-matrices Sigma.
BaseFloat max_cond
is allowed to change.
BaseFloat max_impr_u
Maximum improvement/frame allowed for u [0.25, carried over from ML update.].
BaseFloat tau_u
Tau value for smoothing update of speaker-subspace weight projections (u)
BaseFloat lrate_Sigma
Learning rate used in updating Sigma– default 0.5.
BaseFloat lrate_v
Learning rate used in updating v– default 0.5.
BaseFloat lrate_w
Learning rate used in updating w– default 1.0.
BaseFloat lrate_N
Learning rate used in updating N– default 0.5.
BaseFloat tau_v
Smoothing constant for updates of sub-state vectors v_{jm}.
kaldi::int32 int32
BaseFloat lrate_u
Learning rate used in updating u– default 1.0.
#define KALDI_DISALLOW_COPY_AND_ASSIGN(type)
Definition: kaldi-utils.h:121
BaseFloat tau_M
Smoothing constant for the M quantities (phone-subspace projections)
virtual void Register(const std::string &name, bool *ptr, const std::string &doc)=0
BaseFloat lrate_M
Learning rate used in updating M– default 0.5.
uint16 SgmmUpdateFlagsType
Bitwise OR of the above flags.
Definition: model-common.h:59
BaseFloat tau_c
Tau value for smoothing substate weights (c)
BaseFloat tau_N
Smoothing constant for the N quantities (speaker-subspace projections)
EbwAmSgmm2Updater(const EbwAmSgmm2Options &options)
BaseFloat tau_w
Tau value for smoothing update of phonetic-subspace weight projections (w)
Vector< double > gamma_j_
State occupancies.
EbwAmSgmm2Options()
for an issue in some implementations of SVD.
void Register(OptionsItf *opts)
BaseFloat epsilon
very small value used in SolveQuadraticProblem; workaround
BaseFloat min_substate_weight
Minimum allowed weight in a sub-state.
Class for the accumulators associated with the phonetic-subspace model parameters.