All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
gmm-mixup.cc File Reference
Include dependency graph for gmm-mixup.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

int main(int argc, char *argv[])

Definition at line 27 of file gmm-mixup.cc.

References VectorBase< Real >::Dim(), ParseOptions::GetArg(), KALDI_ERR, KALDI_LOG, AmDiagGmm::MergeByCount(), ParseOptions::NumArgs(), AmDiagGmm::NumPdfs(), ParseOptions::PrintUsage(), AmDiagGmm::Read(), ParseOptions::Read(), TransitionModel::Read(), kaldi::ReadKaldiObject(), ParseOptions::Register(), AmDiagGmm::SplitByCount(), Output::Stream(), Input::Stream(), AmDiagGmm::Write(), and TransitionModel::Write().

27  {
28  try {
29  using namespace kaldi;
30  typedef kaldi::int32 int32;
31 
32  const char *usage =
33  "Does GMM mixing up (and Gaussian merging)\n"
34  "Usage: gmm-mixup [options] <model-in> <state-occs-in> <model-out>\n"
35  "e.g. of mixing up:\n"
36  " gmm-mixup --mix-up=4000 1.mdl 1.occs 2.mdl\n"
37  "e.g. of merging:\n"
38  " gmm-mixup --merge=2000 1.mdl 1.occs 2.mdl\n";
39 
40  bool binary_write = true;
41  int32 mixup = 0;
42  int32 mixdown = 0;
43  BaseFloat perturb_factor = 0.01;
44  BaseFloat power = 0.2;
45  BaseFloat min_count = 20.0;
46 
47  ParseOptions po(usage);
48  po.Register("binary", &binary_write, "Write output in binary mode");
49  po.Register("mix-up", &mixup, "Increase number of mixture components to "
50  "this overall target.");
51  po.Register("min-count", &min_count,
52  "Minimum count enforced while mixing up.");
53  po.Register("mix-down", &mixdown, "If nonzero, merge mixture components to this "
54  "target.");
55  po.Register("power", &power, "If mixing up, power to allocate Gaussians to"
56  " states.");
57  po.Register("perturb-factor", &perturb_factor, "While mixing up, perturb "
58  "means by standard deviation times this factor.");
59 
60  po.Read(argc, argv);
61 
62  if (po.NumArgs() != 3) {
63  po.PrintUsage();
64  exit(1);
65  }
66 
67 
68  std::string model_in_filename = po.GetArg(1),
69  occs_in_filename = po.GetArg(2),
70  model_out_filename = po.GetArg(3);
71 
72  AmDiagGmm am_gmm;
73  TransitionModel trans_model;
74  {
75  bool binary_read;
76  Input ki(model_in_filename, &binary_read);
77  trans_model.Read(ki.Stream(), binary_read);
78  am_gmm.Read(ki.Stream(), binary_read);
79  }
80 
81  if (mixup != 0 || mixdown != 0) {
82 
83  Vector<BaseFloat> occs;
84  ReadKaldiObject(occs_in_filename, &occs);
85  if (occs.Dim() != am_gmm.NumPdfs())
86  KALDI_ERR << "Dimension of state occupancies " << occs.Dim()
87  << " does not match num-pdfs " << am_gmm.NumPdfs();
88 
89  if (mixdown != 0)
90  am_gmm.MergeByCount(occs, mixdown, power, min_count);
91 
92  if (mixup != 0)
93  am_gmm.SplitByCount(occs, mixup, perturb_factor,
94  power, min_count);
95  }
96 
97  {
98  Output ko(model_out_filename, binary_write);
99  trans_model.Write(ko.Stream(), binary_write);
100  am_gmm.Write(ko.Stream(), binary_write);
101  }
102 
103  KALDI_LOG << "Written model to " << model_out_filename;
104  } catch(const std::exception &e) {
105  std::cerr << e.what() << '\n';
106  return -1;
107  }
108 }
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
void MergeByCount(const Vector< BaseFloat > &state_occs, int32 target_components, BaseFloat power, BaseFloat min_count)
Definition: am-diag-gmm.cc:125
void Write(std::ostream &out_stream, bool binary) const
Definition: am-diag-gmm.cc:163
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:818
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
int32 NumPdfs() const
Definition: am-diag-gmm.h:82
void Read(std::istream &is, bool binary)
#define KALDI_ERR
Definition: kaldi-error.h:127
void Write(std::ostream &os, bool binary) const
#define KALDI_LOG
Definition: kaldi-error.h:133
void Read(std::istream &in_stream, bool binary)
Definition: am-diag-gmm.cc:147
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:62
void SplitByCount(const Vector< BaseFloat > &state_occs, int32 target_components, float perturb_factor, BaseFloat power, BaseFloat min_count)
Definition: am-diag-gmm.cc:102