gmm-global-est-fmllr.cc
Go to the documentation of this file.
1 // gmmbin/gmm-global-est-fmllr.cc
2 
3 // Copyright 2009-2011 Microsoft Corporation; Saarland University
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include <string>
21 using std::string;
22 #include <vector>
23 using std::vector;
24 
25 #include "base/kaldi-common.h"
26 #include "util/common-utils.h"
27 #include "gmm/am-diag-gmm.h"
28 #include "hmm/transition-model.h"
30 
31 namespace kaldi {
33  const DiagGmm &gmm,
34  const std::string &key,
35  RandomAccessBaseFloatVectorReader *weights_reader,
37  AccumFullGmm *fullcov_stats) {
38  Vector<BaseFloat> weights;
39  if (weights_reader->IsOpen()) {
40  if (!weights_reader->HasKey(key)) {
41  KALDI_WARN << "No weights present for utterance " << key;
42  return false;
43  }
44  weights = weights_reader->Value(key);
45  }
46  int32 num_frames = feats.NumRows();
47  if (gselect_reader->IsOpen()) {
48  if (!gselect_reader->HasKey(key)) {
49  KALDI_WARN << "No gselect information present for utterance " << key;
50  return false;
51  }
52  const std::vector<std::vector<int32> > &gselect(gselect_reader->Value(key));
53  if (gselect.size() != num_frames) {
54  KALDI_WARN << "gselect information has wrong size for utterance " << key;
55  return false;
56  }
57  for (int32 t = 0; t < num_frames; t++) {
58  const std::vector<int32> &this_gselect(gselect[t]);
59  BaseFloat weight = (weights.Dim() != 0 ? weights(t) : 1.0);
60  if (weight != 0.0) {
61  Vector<BaseFloat> post(this_gselect.size());
62  gmm.LogLikelihoodsPreselect(feats.Row(t), this_gselect, &post);
63  post.ApplySoftMax(); // get posteriors.
64  post.Scale(weight); // scale by the weight for this frame.
65  for (size_t i = 0; i < this_gselect.size(); i++)
66  fullcov_stats->AccumulateForComponent(feats.Row(t),
67  this_gselect[i], post(i));
68  }
69  }
70  } else {
71  for (int32 t = 0; t < num_frames; t++) {
72  BaseFloat weight = (weights.Dim() != 0 ? weights(t) : 1.0);
73  if (weight != 0.0)
74  fullcov_stats->AccumulateFromDiag(gmm, feats.Row(t), weight);
75  }
76  }
77  return true;
78 }
79 
80 
81 }
82 
83 int main(int argc, char *argv[]) {
84  try {
85  typedef kaldi::int32 int32;
86  using namespace kaldi;
87  const char *usage =
88  "Estimate global fMLLR transforms, either per utterance or for the supplied\n"
89  "set of speakers (spk2utt option). Reads features, and (with --weights option)\n"
90  "weights for each frame (also see --gselect option)\n"
91  "Usage: gmm-global-est-fmllr [options] <gmm-in> <feature-rspecifier> <transform-wspecifier>\n";
92 
93  ParseOptions po(usage);
94  FmllrOptions fmllr_opts;
95  string spk2utt_rspecifier, gselect_rspecifier, weights_rspecifier,
96  alignment_model;
97 
98 
99  po.Register("spk2utt", &spk2utt_rspecifier, "rspecifier for speaker to "
100  "utterance-list map");
101  po.Register("gselect", &gselect_rspecifier, "rspecifier for gselect objects "
102  "to limit the #Gaussians accessed on each frame.");
103  po.Register("weights", &weights_rspecifier, "rspecifier for a vector of floats "
104  "for each utterance, that's a per-frame weight.");
105  po.Register("align-model", &alignment_model, "rxfilename for a model in the "
106  "speaker-independent space, to get Gaussian alignments from");
107 
108  fmllr_opts.Register(&po);
109 
110  po.Read(argc, argv);
111 
112  if (po.NumArgs() != 3) {
113  po.PrintUsage();
114  exit(1);
115  }
116 
117  string gmm_rxfilename = po.GetArg(1),
118  feature_rspecifier = po.GetArg(2),
119  trans_wspecifier = po.GetArg(3);
120 
121  DiagGmm gmm;
122  ReadKaldiObject(gmm_rxfilename, &gmm);
123  DiagGmm ali_gmm_read;
124  if (alignment_model != "") {
125  bool binary;
126  Input ki(gmm_rxfilename, &binary);
127  ali_gmm_read.Read(ki.Stream(), binary);
128  }
129  DiagGmm &ali_gmm = (alignment_model != "" ? ali_gmm_read : gmm);
130 
131  RandomAccessBaseFloatVectorReader weights_reader(weights_rspecifier);
132  RandomAccessInt32VectorVectorReader gselect_reader(gselect_rspecifier);
133 
134  double tot_impr = 0.0, tot_t = 0.0;
135 
136  BaseFloatMatrixWriter transform_writer(trans_wspecifier);
137 
138  int32 num_done = 0, num_err = 0;
139 
140  if (spk2utt_rspecifier != "") { // per-speaker adaptation
141  SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
142  RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier);
143 
144  for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
145  AccumFullGmm fullcov_stats(gmm.NumGauss(), gmm.Dim(), kGmmAll);
146  string spk = spk2utt_reader.Key();
147  const vector<string> &uttlist = spk2utt_reader.Value();
148  for (size_t i = 0; i < uttlist.size(); i++) {
149  std::string utt = uttlist[i];
150  if (!feature_reader.HasKey(utt)) {
151  KALDI_WARN << "Did not find features for utterance " << utt;
152  continue;
153  }
154  const Matrix<BaseFloat> &feats = feature_reader.Value(utt);
155 
156  if (AccumulateForUtterance(feats, ali_gmm, utt, &weights_reader,
157  &gselect_reader, &fullcov_stats)) num_done++;
158  else num_err++;
159  } // end looping over all utterances of the current speaker
160 
161  BaseFloat impr, spk_tot_t;
162  { // Compute the transform and write it out.
163  Matrix<BaseFloat> transform(gmm.Dim(), gmm.Dim()+1);
164  transform.SetUnit();
165  FmllrDiagGmmAccs spk_stats(gmm, fullcov_stats);
166  spk_stats.Update(fmllr_opts, &transform, &impr, &spk_tot_t);
167  transform_writer.Write(spk, transform);
168  }
169  KALDI_LOG << "For speaker " << spk << ", auxf-impr from fMLLR is "
170  << (impr/spk_tot_t) << ", over " << spk_tot_t << " frames.";
171  tot_impr += impr;
172  tot_t += spk_tot_t;
173  } // end looping over speakers
174  } else { // per-utterance adaptation
175  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
176  for (; !feature_reader.Done(); feature_reader.Next()) {
177  string utt = feature_reader.Key();
178 
179  const Matrix<BaseFloat> &feats = feature_reader.Value();
180 
181  AccumFullGmm fullcov_stats(gmm.NumGauss(), gmm.Dim(), kGmmAll);
182 
183  if (AccumulateForUtterance(feats, ali_gmm, utt, &weights_reader,
184  &gselect_reader, &fullcov_stats)) {
185  BaseFloat impr, utt_tot_t;
186  { // Compute the transform and write it out.
187  Matrix<BaseFloat> transform(gmm.Dim(), gmm.Dim()+1);
188  transform.SetUnit();
189  FmllrDiagGmmAccs spk_stats(gmm, fullcov_stats);
190  spk_stats.Update(fmllr_opts, &transform, &impr, &utt_tot_t);
191  transform_writer.Write(utt, transform);
192  }
193  KALDI_LOG << "For utterance " << utt << ", auxf-impr from fMLLR is "
194  << (impr/utt_tot_t) << ", over " << utt_tot_t << " frames.";
195  tot_impr += impr;
196  tot_t += utt_tot_t;
197  num_done++;
198  } else num_err++;
199 
200  }
201  }
202 
203  KALDI_LOG << "Done " << num_done << " files, " << num_err
204  << " with errors.";
205  KALDI_LOG << "Overall fMLLR auxf impr per frame is "
206  << (tot_impr / tot_t) << " over " << tot_t << " frames.";
207  return (num_done != 0 ? 0 : 1);
208  } catch(const std::exception &e) {
209  std::cerr << e.what();
210  return -1;
211  }
212 }
213 
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
int32 Dim() const
Returns the dimensionality of the Gaussian mean vectors.
Definition: diag-gmm.h:74
int main(int argc, char *argv[])
void LogLikelihoodsPreselect(const VectorBase< BaseFloat > &data, const std::vector< int32 > &indices, Vector< BaseFloat > *loglikes) const
Outputs the per-component log-likelihoods of a subset of mixture components.
Definition: diag-gmm.cc:566
BaseFloat AccumulateFromDiag(const DiagGmm &gmm, const VectorBase< BaseFloat > &data, BaseFloat frame_posterior)
Accumulate for all components given a diagonal-covariance GMM.
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
This does not work with multiple feature transforms.
void AccumulateForUtterance(const Matrix< BaseFloat > &feats, const GaussPost &gpost, const TransitionModel &trans_model, const AmDiagGmm &am_gmm, FmllrDiagGmmAccs *spk_stats)
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
void SetUnit()
Sets to zero, except ones along diagonal [for non-square matrices too].
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
std::istream & Stream()
Definition: kaldi-io.cc:826
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
const T & Value(const std::string &key)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
Class for computing the maximum-likelihood estimates of the parameters of a Gaussian mixture model...
Definition: mle-full-gmm.h:74
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
#define KALDI_WARN
Definition: kaldi-error.h:150
int32 NumGauss() const
Returns the number of mixture components in the GMM.
Definition: diag-gmm.h:72
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
bool HasKey(const std::string &key)
void Register(OptionsItf *opts)
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void Read(std::istream &in, bool binary)
Definition: diag-gmm.cc:728
A class representing a vector.
Definition: kaldi-vector.h:406
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
Definition for Gaussian Mixture Model with diagonal covariances.
Definition: diag-gmm.h:42
#define KALDI_LOG
Definition: kaldi-error.h:153
void AccumulateForComponent(const VectorBase< BaseFloat > &data, int32 comp_index, BaseFloat weight)
Accumulate for a single component, given the posterior.
Definition: mle-full-gmm.cc:96
void Update(const FmllrOptions &opts, MatrixBase< BaseFloat > *fmllr_mat, BaseFloat *objf_impr, BaseFloat *count)
Update.