All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
fgmm-global-init-from-accs.cc File Reference
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "gmm/full-gmm.h"
#include "gmm/mle-full-gmm.h"
Include dependency graph for fgmm-global-init-from-accs.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

int main (int argc, char *argv[])

Definition at line 27 of file fgmm-global-init-from-accs.cc.

References SpMatrix< Real >::AddDiagVec(), SpMatrix< Real >::AddVec2(), SpMatrix< Real >::ApplyFloor(), FullGmm::ComputeGconsts(), SpMatrix< Real >::CopyFromSp(), VectorBase< Real >::CopyRowFromMat(), MatrixBase< Real >::CopyRowFromVec(), kaldi::diag, VectorBase< Real >::Dim(), ParseOptions::GetArg(), rnnlm::i, SpMatrix< Real >::InvertDouble(), KALDI_LOG, KALDI_WARN, kaldi::kSetZero, MleFullGmmOptions::max_condition, SpMatrix< Real >::MaxAbsEig(), ParseOptions::NumArgs(), ParseOptions::PrintUsage(), ParseOptions::Read(), MleFullGmmOptions::Register(), ParseOptions::Register(), PackedMatrix< Real >::Scale(), VectorBase< Real >::Scale(), FullGmm::SetInvCovarsAndMeans(), VectorBase< Real >::SetRandn(), FullGmm::SetWeights(), Input::Stream(), VectorBase< Real >::Sum(), MleFullGmmOptions::variance_floor, and kaldi::WriteKaldiObject().

27  {
28  try {
29  using namespace kaldi;
30  typedef int32 int32;
31  MleFullGmmOptions gmm_opts;
32 
33  const char *usage =
34  "Initialize a full-covariance GMM from the accumulated stats.\n"
35  "This binary is similar to fgmm-global-est, but does not use "
36  "a preexisting model. See also fgmm-global-est.\n"
37  "Usage: fgmm-global-init-from-accs [options] <stats-in> "
38  "<number-of-components> <model-out>\n";
39 
40  bool binary_write = true;
41  ParseOptions po(usage);
42  po.Register("binary", &binary_write, "Write output in binary mode");
43  gmm_opts.Register(&po);
44 
45  po.Read(argc, argv);
46 
47  if (po.NumArgs() != 3) {
48  po.PrintUsage();
49  exit(1);
50  }
51 
52  std::string stats_filename = po.GetArg(1),
53  model_out_filename = po.GetArg(3);
54  int32 num_components = atoi(po.GetArg(2).c_str());
55 
56  AccumFullGmm gmm_accs;
57  {
58  bool binary;
59  Input ki(stats_filename, &binary);
60  gmm_accs.Read(ki.Stream(), binary, true /* add accs. */);
61  }
62 
63  int32 num_gauss = gmm_accs.NumGauss(), dim = gmm_accs.Dim(),
64  tot_floored = 0, gauss_floored = 0, tot_low_occ = 0;
65 
66  FullGmm fgmm(num_components, dim);
67 
68  Vector<BaseFloat> weights(num_gauss);
69  Matrix<BaseFloat> means(num_gauss, dim);
70  std::vector<SpMatrix<BaseFloat> > invcovars;
71 
72  for (int32 i = 0; i < num_components; i++) {
73  BaseFloat occ = gmm_accs.occupancy()(i);
74  weights(i) = occ;
75  Vector<BaseFloat> mean(dim, kSetZero);
76  SpMatrix<BaseFloat> covar(dim, kSetZero);
77 
78  // If the occupancy for a Gaussian is very low, set it to a small value.
79  if (occ < 1e-10) {
80  weights(i) = 1e-10;
81  mean.SetRandn();
82  Vector<BaseFloat> diag(mean.Dim());
83  diag.Set(1.0);
84  covar.AddDiagVec(1.0, diag);
85  tot_low_occ++;
86  // This is the typical case.
87  } else {
88  mean.CopyRowFromMat(gmm_accs.mean_accumulator(), i);
89  mean.Scale(1.0 / occ);
90  covar.CopyFromSp(gmm_accs.covariance_accumulator()[i]);
91  covar.Scale(1.0 / occ);
92  covar.AddVec2(-1.0, mean); // subtract squared means.
93  }
94  means.CopyRowFromVec(mean, i);
95 
96  // Floor variance Eigenvalues.
97  BaseFloat floor = std::max(
98  static_cast<BaseFloat>(gmm_opts.variance_floor),
99  static_cast<BaseFloat>(covar.MaxAbsEig() / gmm_opts.max_condition));
100  int32 floored = covar.ApplyFloor(floor);
101  if (floored) {
102  tot_floored += floored;
103  gauss_floored++;
104  }
105  covar.InvertDouble();
106  invcovars.push_back(covar);
107  }
108  weights.Scale(1.0 / weights.Sum());
109  fgmm.SetWeights(weights);
110  fgmm.SetInvCovarsAndMeans(invcovars, means);
111  int32 num_bad = fgmm.ComputeGconsts();
112  KALDI_LOG << "FullGmm has " << num_bad << " bad GConsts";
113 
114  if (tot_floored > 0) {
115  KALDI_WARN << tot_floored << " variances floored in " << gauss_floored
116  << " Gaussians.";
117  }
118  if (tot_low_occ > 0) {
119  KALDI_WARN << tot_low_occ << " out of " << num_gauss
120  << " Gaussians had very low occupancy.";
121  }
122 
123  WriteKaldiObject(fgmm, model_out_filename, binary_write);
124 
125  KALDI_LOG << "Written model to " << model_out_filename;
126  } catch(const std::exception &e) {
127  std::cerr << e.what() << '\n';
128  return -1;
129  }
130 }
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
Configuration variables like variance floor, minimum occupancy, etc.
Definition: mle-full-gmm.h:38
Definition for Gaussian Mixture Model with full covariances.
Definition: full-gmm.h:40
BaseFloat variance_floor
Floor on eigenvalues of covariance matrices.
Definition: mle-full-gmm.h:44
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Register(OptionsItf *opts)
Definition: mle-full-gmm.h:56
Class for computing the maximum-likelihood estimates of the parameters of a Gaussian mixture model...
Definition: mle-full-gmm.h:74
#define KALDI_WARN
Definition: kaldi-error.h:130
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Definition: kaldi-io.h:257
BaseFloat max_condition
Maximum condition number of covariance matrices (apply floor to eigenvalues if they pass this)...
Definition: mle-full-gmm.h:47
#define KALDI_LOG
Definition: kaldi-error.h:133