model-common.cc
Go to the documentation of this file.
1 // gmm/model-common.cc
2 
3 // Copyright 2009-2011 Microsoft Corporation
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABILITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include "matrix/matrix-lib.h"
21 #include "gmm/model-common.h"
22 #include <queue>
23 #include <numeric>
24 
25 namespace kaldi {
26 GmmFlagsType StringToGmmFlags(std::string str) {
27  GmmFlagsType flags = 0;
28  for (const char *c = str.c_str(); *c != '\0'; c++) {
29  switch (*c) {
30  case 'm': flags |= kGmmMeans; break;
31  case 'v': flags |= kGmmVariances; break;
32  case 'w': flags |= kGmmWeights; break;
33  case 't': flags |= kGmmTransitions; break;
34  case 'a': flags |= kGmmAll; break;
35  default: KALDI_ERR << "Invalid element " << CharToString(*c)
36  << " of GmmFlagsType option string "
37  << str;
38  }
39  }
40  return flags;
41 }
42 
43 std::string GmmFlagsToString(GmmFlagsType flags) {
44  std::string ans;
45  if (flags & kGmmMeans) ans += "m";
46  if (flags & kGmmVariances) ans += "v";
47  if (flags & kGmmWeights) ans += "w";
48  if (flags & kGmmTransitions) ans += "t";
49  return ans;
50 }
51 
53 KALDI_ASSERT((flags & ~kGmmAll) == 0); // make sure only valid flags are present.
54  if (flags & kGmmVariances) flags |= kGmmMeans;
55  if (flags & kGmmMeans) flags |= kGmmWeights;
56  if (!(flags & kGmmWeights)) {
57  KALDI_WARN << "Adding in kGmmWeights (\"w\") to empty flags.";
58  flags |= kGmmWeights; // Just add this in regardless:
59  // if user wants no stats, this will stop programs from crashing due to dim mismatches.
60  }
61  return flags;
62 }
63 
65  SgmmUpdateFlagsType flags = 0;
66  for (const char *c = str.c_str(); *c != '\0'; c++) {
67  switch (*c) {
68  case 'v': flags |= kSgmmPhoneVectors; break;
69  case 'M': flags |= kSgmmPhoneProjections; break;
70  case 'w': flags |= kSgmmPhoneWeightProjections; break;
71  case 'S': flags |= kSgmmCovarianceMatrix; break;
72  case 'c': flags |= kSgmmSubstateWeights; break;
73  case 'N': flags |= kSgmmSpeakerProjections; break;
74  case 't': flags |= kSgmmTransitions; break;
75  case 'u': flags |= kSgmmSpeakerWeightProjections; break;
76  case 'a': flags |= kSgmmAll; break;
77  default: KALDI_ERR << "Invalid element " << CharToString(*c)
78  << " of SgmmUpdateFlagsType option string "
79  << str;
80  }
81  }
82  return flags;
83 }
84 
85 
87  SgmmWriteFlagsType flags = 0;
88  for (const char *c = str.c_str(); *c != '\0'; c++) {
89  switch (*c) {
90  case 'g': flags |= kSgmmGlobalParams; break;
91  case 's': flags |= kSgmmStateParams; break;
92  case 'n': flags |= kSgmmNormalizers; break;
93  case 'u': flags |= kSgmmBackgroundGmms; break;
94  case 'a': flags |= kSgmmAll; break;
95  default: KALDI_ERR << "Invalid element " << CharToString(*c)
96  << " of SgmmWriteFlagsType option string "
97  << str;
98  }
99  }
100  return flags;
101 }
102 
103 struct CountStats {
105  : pdf_index(p), num_components(n), occupancy(occ) {}
109  bool operator < (const CountStats &other) const {
110  return occupancy/(num_components+1.0e-10) <
111  other.occupancy/(other.num_components+1.0e-10);
112  }
113 };
114 
115 
116 void GetSplitTargets(const Vector<BaseFloat> &state_occs,
117  int32 target_components,
118  BaseFloat power,
119  BaseFloat min_count,
120  std::vector<int32> *targets) {
121  std::priority_queue<CountStats> split_queue;
122  int32 num_pdfs = state_occs.Dim();
123 
124  for (int32 pdf_index = 0; pdf_index < num_pdfs; pdf_index++) {
125  BaseFloat occ = pow(state_occs(pdf_index), power);
126  // initialize with one Gaussian per PDF, to put a floor
127  // of 1 on the #Gauss
128  split_queue.push(CountStats(pdf_index, 1, occ));
129  }
130 
131  for (int32 num_gauss = num_pdfs; num_gauss < target_components;) {
132  CountStats state_to_split = split_queue.top();
133  if (state_to_split.occupancy == 0) {
134  KALDI_WARN << "Could not split up to " << target_components
135  << " due to min-count = " << min_count
136  << " (or no counts at all)\n";
137  break;
138  }
139  split_queue.pop();
140  BaseFloat orig_occ = state_occs(state_to_split.pdf_index);
141  if ((state_to_split.num_components+1) * min_count >= orig_occ) {
142  state_to_split.occupancy = 0; // min-count active -> disallow splitting
143  // this state any more by setting occupancy = 0.
144  } else {
145  state_to_split.num_components++;
146  num_gauss++;
147  }
148  split_queue.push(state_to_split);
149  }
150  targets->resize(num_pdfs);
151  while (!split_queue.empty()) {
152  int32 pdf_index = split_queue.top().pdf_index;
153  int32 pdf_tgt_comp = split_queue.top().num_components;
154  (*targets)[pdf_index] = pdf_tgt_comp;
155  split_queue.pop();
156  }
157 }
158 
159 } // End namespace kaldi
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
uint16 SgmmWriteFlagsType
Bitwise OR of the above flags.
Definition: model-common.h:70
GmmFlagsType AugmentGmmFlags(GmmFlagsType f)
Returns "augmented" version of flags: e.g.
Definition: model-common.cc:52
GmmFlagsType StringToGmmFlags(std::string str)
Convert string which is some subset of "mvwta" to flags.
Definition: model-common.cc:26
CountStats(int32 p, int32 n, BaseFloat occ)
bool operator<(const CountStats &other) const
kaldi::int32 int32
SgmmUpdateFlagsType StringToSgmmUpdateFlags(std::string str)
Definition: model-common.cc:64
uint16 GmmFlagsType
Bitwise OR of the above flags.
Definition: model-common.h:35
void GetSplitTargets(const Vector< BaseFloat > &state_occs, int32 target_components, BaseFloat power, BaseFloat min_count, std::vector< int32 > *targets)
Get Gaussian-mixture or substate-mixture splitting targets, according to a power rule (e...
t .. not really part of SGMM.
Definition: model-common.h:55
float BaseFloat
Definition: kaldi-types.h:29
uint16 SgmmUpdateFlagsType
Bitwise OR of the above flags.
Definition: model-common.h:59
The letters correspond to the variable names.
Definition: model-common.h:48
struct rnnlm::@11::@12 n
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
std::string CharToString(const char &c)
Definition: kaldi-utils.cc:36
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
std::string GmmFlagsToString(GmmFlagsType flags)
Convert GMM flags to string.
Definition: model-common.cc:43
SgmmUpdateFlagsType StringToSgmmWriteFlags(std::string str)
Definition: model-common.cc:86
u [ for SSGMM ]
Definition: model-common.h:56