All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
fgmm-global-mixdown.cc File Reference
#include <queue>
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "gmm/full-gmm.h"
Include dependency graph for fgmm-global-mixdown.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

int main(int argc, char *argv[])

Definition at line 26 of file fgmm-global-mixdown.cc.

References count, SequentialTableReader< Holder >::Done(), ParseOptions::GetArg(), rnnlm::i, rnnlm::j, KALDI_ASSERT, KALDI_LOG, KALDI_VLOG, FullGmm::MergePreselect(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), FullGmm::NumGauss(), ParseOptions::PrintUsage(), ParseOptions::Read(), FullGmm::Read(), ParseOptions::Register(), Input::Stream(), SequentialTableReader< Holder >::Value(), and kaldi::WriteKaldiObject().

26  {
27  using namespace kaldi;
28  typedef kaldi::int32 int32;
29  try {
30  const char *usage =
31  "Merge Gaussians in a full-covariance GMM to get a smaller number;\n"
32  "this program supports a --gselect option which is used to select\n"
33  "\"good\" pairs of Gaussians to consider merging (pairs that most often\n"
34  "co-occur in the gselect information are considered). If no gselect\n"
35  "info supplied, we consider all pairs (very slow for big models).\n"
36  "Usage: fgmm-global-mixdown [options] <model-in> <model-out>\n"
37  "e.g.: fgmm-global-mixdown --gselect=gselect.1 --mixdown-target=120 1.ubm 2.ubm\n"
38  "Note: --mixdown-target option is required.\n";
39 
40  bool binary_write = true;
41  std::string gselect_rspecifier;
42  int32 mixdown_target = -1, num_pairs = 20000;
43  BaseFloat power = 0.75; // Power used in choosing pairs; between 0.5 and 1 make sense.
44  ParseOptions po(usage);
45  po.Register("binary", &binary_write, "Write output in binary mode");
46  po.Register("gselect", &gselect_rspecifier, "Gaussian-selection info, used "
47  "to select most promising pairs");
48  po.Register("num-pairs", &num_pairs, "Number of pairs of Gaussians to try merging "
49  "(only relevant if you use --gselect option");
50  po.Register("mixdown-target", &mixdown_target,
51  "Number of Gaussians we want in mixed-down GMM.");
52  po.Register("power", &power,
53  "Power used in choosing pairs from gselect (should be between 0.5 and 1)");
54 
55  po.Read(argc, argv);
56 
57  if (po.NumArgs() != 2) {
58  po.PrintUsage();
59  exit(1);
60  }
61 
62  std::string model_in_filename = po.GetArg(1),
63  model_out_filename = po.GetArg(2);
64 
65  KALDI_ASSERT(mixdown_target >= 0 && "--mixdown-target option is required and must be >0.");
66 
67  FullGmm fgmm;
68  {
69  bool binary_read;
70  Input ki(model_in_filename, &binary_read);
71  fgmm.Read(ki.Stream(), binary_read);
72  }
73  std::vector<std::pair<int32, int32> > pairs;
74  if (gselect_rspecifier == "") { // use all pairs.
75  for (int32 i = 0; i < fgmm.NumGauss(); i++)
76  for (int32 j = 0; j < i; j++) pairs.push_back(std::make_pair(i, j));
77  } else {
78  unordered_map<std::pair<int32, int32>, int32, PairHasher<int32> > counts; // co-occurrence map:
79  // if i <= j, then maps from (i,j) -> #co-occurrences in gselect info.
80  SequentialInt32VectorVectorReader gselect_reader(gselect_rspecifier);
81  for (; !gselect_reader.Done(); gselect_reader.Next()) {
82  const std::vector<std::vector<int32> > &gselect = gselect_reader.Value();
83  for (int32 i = 0; i < gselect.size(); i++) {
84  for (int32 j = 0; j < gselect[i].size(); j++) {
85  for (int32 k = 0; k < gselect[i].size(); k++) {
86  int32 idx1 = gselect[i][j], idx2 = gselect[i][k];
87  if (idx1 <= idx2) {
88  std::pair<int32, int32> pr(idx1, idx2);
89  if (counts.count(pr) == 0) counts[pr] = 1;
90  else counts[pr]++;
91  }
92  }
93  }
94  }
95  }
96  // take greatest according to count(i,j) / pow(count(i,i)*count(j,j), pow)
97  typedef std::pair<BaseFloat, std::pair<int32,int32> > QueueElem;
98  std::priority_queue<QueueElem> queue;
99  for (unordered_map<std::pair<int32, int32>, int32, PairHasher<int32> >::iterator iter = counts.begin();
100  iter != counts.end(); ++iter) {
101  int32 idx1 = iter->first.first, idx2 = iter->first.second,
102  count = iter->second;
103  if (idx1 != idx2) {
104  BaseFloat x = counts[std::make_pair(idx1,idx1)] * counts[std::make_pair(idx2, idx2)];
105  BaseFloat f = count / std::pow(x, power);
106  queue.push(std::make_pair(f, iter->first));
107  }
108  }
109  while (!queue.empty() && static_cast<int32>(pairs.size()) < num_pairs) {
110  KALDI_VLOG(2) << "Pair is " << queue.top().second.first << ", "
111  << queue.top().second.second;
112  pairs.push_back(queue.top().second); // the "num_pairs" "best" pairs of
113  queue.pop();
114  // indices, based on this co-occurrence statistic.
115  }
116  }
117  KALDI_LOG << "Selected " << pairs.size() << " pairs of Gaussians to merge, "
118  << "now doing merging.";
119  int32 orig_ngauss = fgmm.NumGauss();
120  BaseFloat like_change = fgmm.MergePreselect(mixdown_target, pairs);
121  int32 cur_ngauss = fgmm.NumGauss();
122  KALDI_LOG << "Mixed down GMM from " << orig_ngauss << " to "
123  << cur_ngauss << ", likelihood change was " << like_change;
124 
125  WriteKaldiObject(fgmm, model_out_filename, binary_write);
126 
127  KALDI_LOG << "Wrote model to " << model_out_filename;
128  } catch(const std::exception &e) {
129  std::cerr << e.what() << '\n';
130  return -1;
131  }
132 }
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
int32 NumGauss() const
Returns the number of mixture components in the GMM.
Definition: full-gmm.h:58
Definition for Gaussian Mixture Model with full covariances.
Definition: full-gmm.h:40
const size_t count
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
void Read(std::istream &is, bool binary)
Definition: full-gmm.cc:813
BaseFloat MergePreselect(int32 target_components, const std::vector< std::pair< int32, int32 > > &preselect_pairs)
Merge the components and remember the order in which the components were merged (flat list of pairs);...
Definition: full-gmm.cc:382
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:169
#define KALDI_VLOG(v)
Definition: kaldi-error.h:136
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Definition: kaldi-io.h:257
#define KALDI_LOG
Definition: kaldi-error.h:133
A hashing function-object for pairs of ints.
Definition: stl-utils.h:237