fgmm-global-merge.cc
Go to the documentation of this file.
1 // fgmmbin/fgmm-global-merge.cc
2 
3 // Copyright 2009-2011 Microsoft Corporation
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include "util/common-utils.h"
21 #include "gmm/full-gmm.h"
22 #include "gmm/mle-full-gmm.h"
23 
24 namespace kaldi {
25 
28 void MergeFullGmm(const FullGmm &src, FullGmm *dst) {
29  FullGmm dst_copy;
30  dst_copy.CopyFromFullGmm(*dst);
31  KALDI_ASSERT(src.NumGauss() != 0 && dst_copy.NumGauss() != 0
32  && src.Dim() == dst_copy.Dim());
33  int32 src_num_mix = src.NumGauss(), dst_num_mix = dst_copy.NumGauss(),
34  num_mix = src_num_mix + dst_num_mix, dim = src.Dim();
35  dst->Resize(num_mix, dim);
36 
37  std::vector<SpMatrix<BaseFloat> > invcovars(num_mix);
38  for(int32 i = 0; i < dst_num_mix; i++) {
39  invcovars[i].Resize(dim);
40  invcovars[i].CopyFromSp(dst_copy.inv_covars()[i]);
41  }
42  for(int32 i = 0; i < src_num_mix; i++) {
43  invcovars[i+dst_num_mix].Resize(dim);
44  invcovars[i+dst_num_mix].CopyFromSp(src.inv_covars()[i]);
45  }
46  Matrix<BaseFloat> means_invcovars(num_mix, dim);
47  means_invcovars.Range(0, dst_num_mix, 0, dim).CopyFromMat(dst_copy.means_invcovars());
48  means_invcovars.Range(dst_num_mix, src_num_mix, 0, dim).CopyFromMat(src.means_invcovars());
49  dst->SetInvCovarsAndMeansInvCovars(invcovars, means_invcovars);
50 
51  Vector<BaseFloat> weights(num_mix); // initialized to zero.
52  // weight proportional to #Gaussians, so that if we combine a number of
53  // models with same #Gaussians, they all get the same weight.
54  BaseFloat src_weight = src_num_mix / static_cast<BaseFloat>(num_mix),
55  dst_weight = dst_num_mix / static_cast<BaseFloat>(num_mix);
56  weights.Range(0, dst_num_mix).AddVec(dst_weight, dst_copy.weights());
57  weights.Range(dst_num_mix, src_num_mix).AddVec(src_weight, src.weights());
58  dst->SetWeights(weights);
59  dst->ComputeGconsts();
60 }
61 
62 }
63 
64 
65 int main(int argc, char *argv[]) {
66  try {
67  using namespace kaldi;
68  typedef kaldi::int32 int32;
69 
70  const char *usage =
71  "Combine a number of GMMs into a larger GMM, with #Gauss = \n"
72  " sum(individual #Gauss)). Output full GMM, and a text file with\n"
73  " sizes of each individual GMM.\n"
74  "Usage: fgmm-global-merge [options] fgmm-out sizes-file-out fgmm-in1 fgmm-in2 ...\n";
75 
76  bool binary = true;
77  ParseOptions po(usage);
78  po.Register("binary", &binary, "Write output in binary mode");
79  po.Read(argc, argv);
80 
81  if (po.NumArgs() < 4) {
82  po.PrintUsage();
83  exit(1);
84  }
85 
86  std::string fgmm_out_filename = po.GetArg(1),
87  sizes_out_filename = po.GetArg(2);
88 
89  FullGmm fgmm;
90  Output sizes_ko(sizes_out_filename, false); // false == not binary.
91 
92  for (int i = 3, max = po.NumArgs(); i <= max; i++) {
93  std::string stats_in_filename = po.GetArg(i);
94  bool binary_read;
95  Input ki(stats_in_filename, &binary_read);
96  if (i==3) {
97  fgmm.Read(ki.Stream(), binary_read);
98  sizes_ko.Stream() << fgmm.NumGauss() << ' ';
99  } else {
100  FullGmm fgmm2;
101  fgmm2.Read(ki.Stream(), binary_read);
102  sizes_ko.Stream() << fgmm2.NumGauss() << ' ';
103  MergeFullGmm(fgmm2, &fgmm);
104  }
105  }
106  sizes_ko.Stream() << "\n";
107 
108  // Write out the model
109  WriteKaldiObject(fgmm, fgmm_out_filename, binary);
110  KALDI_LOG << "Written merged GMM to " << fgmm_out_filename;
111  } catch(const std::exception &e) {
112  std::cerr << e.what() << '\n';
113  return -1;
114  }
115 }
116 
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void SetWeights(const Vector< Real > &w)
Mutators for both float or double.
Definition: full-gmm-inl.h:31
const std::vector< SpMatrix< BaseFloat > > & inv_covars() const
Definition: full-gmm.h:146
int32 Dim() const
Returns the dimensionality of the Gaussian mean vectors.
Definition: full-gmm.h:60
int32 ComputeGconsts()
Sets the gconsts.
Definition: full-gmm.cc:92
Definition for Gaussian Mixture Model with full covariances.
Definition: full-gmm.h:40
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
kaldi::int32 int32
void Register(const std::string &name, bool *ptr, const std::string &doc)
void CopyFromFullGmm(const FullGmm &fullgmm)
Copies from given FullGmm.
Definition: full-gmm.cc:65
std::istream & Stream()
Definition: kaldi-io.cc:826
void Resize(int32 nMix, int32 dim)
Resizes arrays to this dim. Does not initialize data.
Definition: full-gmm.cc:41
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
std::ostream & Stream()
Definition: kaldi-io.cc:701
void MergeFullGmm(const FullGmm &src, FullGmm *dst)
merges GMMs by appending Gaussians in "src" to "dst".
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
int32 NumGauss() const
Returns the number of mixture components in the GMM.
Definition: full-gmm.h:58
const Vector< BaseFloat > & weights() const
Definition: full-gmm.h:144
void SetInvCovarsAndMeansInvCovars(const std::vector< SpMatrix< Real > > &invcovars, const Matrix< Real > &means_invcovars)
Use this if setting both, in the class's native format.
Definition: full-gmm-inl.h:67
void Read(std::istream &is, bool binary)
Definition: full-gmm.cc:813
int NumArgs() const
Number of positional parameters (c.f. argc-1).
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Definition: kaldi-io.h:257
SubMatrix< Real > Range(const MatrixIndexT row_offset, const MatrixIndexT num_rows, const MatrixIndexT col_offset, const MatrixIndexT num_cols) const
Return a sub-part of matrix.
Definition: kaldi-matrix.h:202
const Matrix< BaseFloat > & means_invcovars() const
Definition: full-gmm.h:145
#define KALDI_LOG
Definition: kaldi-error.h:153
int main(int argc, char *argv[])
SubVector< Real > Range(const MatrixIndexT o, const MatrixIndexT l)
Returns a sub-vector of a vector (a range of elements).
Definition: kaldi-vector.h:94