gmm-est-hlda.cc File Reference


Functions

int main (int argc, char *argv[])
 

Function Documentation

int main (int argc, char *argv[])

Definition at line 27 of file gmm-est-hlda.cc.

References HldaAccsDiagGmm::FeatureDim(), ParseOptions::GetArg(), KALDI_ASSERT, KALDI_LOG, MatrixBase< Real >::LogDet(), HldaAccsDiagGmm::ModelDim(), ParseOptions::NumArgs(), MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), ParseOptions::PrintUsage(), HldaAccsDiagGmm::Read(), AmDiagGmm::Read(), ParseOptions::Read(), TransitionModel::Read(), kaldi::ReadKaldiObject(), ParseOptions::Register(), Output::Stream(), Input::Stream(), HldaAccsDiagGmm::Update(), AmDiagGmm::Write(), TransitionModel::Write(), and kaldi::WriteKaldiObject().
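main() expects five fixed positional arguments followed by one or more statistics files (positions 6 to NumArgs()). The sketch below reproduces only that argument layout in plain C++, without Kaldi. HldaArgs and ParseHldaArgs are hypothetical names, and args is assumed to hold the positional arguments only (options such as --binary already removed), so args[0] corresponds to po.GetArg(1).

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical container for gmm-est-hlda's positional arguments; the field
// names mirror the local variables in main() but are not Kaldi API.
struct HldaArgs {
  std::string model_in, hldafull_in, model_out, hldafull_out, hldapart_out;
  std::vector<std::string> stats_in;  // one or more accumulator files
};

// Hypothetical helper: args[0] plays the role of po.GetArg(1).
HldaArgs ParseHldaArgs(const std::vector<std::string> &args) {
  // Five fixed filenames plus at least one stats file, as in main().
  if (args.size() < 6)
    throw std::invalid_argument("expected at least 6 positional arguments");
  HldaArgs out;
  out.model_in = args[0];
  out.hldafull_in = args[1];
  out.model_out = args[2];
  out.hldafull_out = args[3];
  out.hldapart_out = args[4];
  // Equivalent of po.GetArg(6) .. po.GetArg(po.NumArgs()).
  for (std::size_t i = 5; i < args.size(); ++i)
    out.stats_in.push_back(args[i]);
  assert(!out.stats_in.empty());
  return out;
}
```

With the usage example from the help string, `ParseHldaArgs({"1.mdl", "1.hldafull", "2.mdl", "2.hldafull", "2.hlda", "1.0.hacc", "1.1.hacc"})` yields two stats files. Throwing instead of calling exit(1) keeps the helper testable; the real tool prints the usage message and exits.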

{
  try {
    using namespace kaldi;
    typedef kaldi::int32 int32;

    const char *usage =
        "Do HLDA update\n"
        "Usage: gmm-est-hlda [options] <model-in> <full-hlda-mat-in> <model-out> <full-hlda-mat-out> <partial-hlda-mat-out> <stats-in1> <stats-in2> ... \n"
        "e.g.: gmm-est-hlda 1.mdl 1.hldafull 2.mdl 2.hldafull 2.hlda 1.0.hacc 1.1.hacc ... \n";

    bool binary = true;  // write in binary if true.

    ParseOptions po(usage);
    po.Register("binary", &binary, "Write output in binary mode");

    po.Read(argc, argv);

    // Five fixed filenames plus at least one statistics file.
    if (po.NumArgs() < 6) {
      po.PrintUsage();
      exit(1);
    }

    std::string model_in_filename = po.GetArg(1),
        hldafull_in_filename = po.GetArg(2),
        model_out_filename = po.GetArg(3),
        hldafull_out_filename = po.GetArg(4),
        hldapart_out_filename = po.GetArg(5);

    // The model file holds the transition model followed by the acoustic model.
    AmDiagGmm am_gmm;
    TransitionModel trans_model;
    {
      bool binary;  // format of the input file; shadows the --binary option here
      Input ki(model_in_filename, &binary);
      trans_model.Read(ki.Stream(), binary);
      am_gmm.Read(ki.Stream(), binary);
    }

    // Arguments 6..NumArgs() are HLDA stats files; add = true accumulates
    // them into hlda_accs.
    HldaAccsDiagGmm hlda_accs;
    for (int32 i = 6; i <= po.NumArgs(); i++) {
      std::string acc_filename = po.GetArg(i);
      bool binary_in, add = true;
      Input ki(acc_filename, &binary_in);
      hlda_accs.Read(ki.Stream(), binary_in, add);
    }

    // The full HLDA matrix must be square, of size FeatureDim x FeatureDim.
    Matrix<BaseFloat> hlda_mat_full;
    ReadKaldiObject(hldafull_in_filename, &hlda_mat_full);
    KALDI_ASSERT(hlda_mat_full.NumRows() == hlda_accs.FeatureDim()
                 && hlda_mat_full.NumCols() == hlda_accs.FeatureDim());

    // The partial matrix has ModelDim rows and FeatureDim columns.
    Matrix<BaseFloat> hlda_mat_part(hlda_accs.ModelDim(),
                                    hlda_accs.FeatureDim());

    // Update() writes the updated model and matrices through its pointer
    // arguments and reports the objective improvement and the frame count.
    BaseFloat objf_impr, count;
    hlda_accs.Update(&am_gmm, &hlda_mat_full, &hlda_mat_part, &objf_impr, &count);

    // Despite the wording of the message, objf_impr / count is the
    // improvement per frame.
    KALDI_LOG << "Updated HLDA, total objf impr is " << (objf_impr/count)
              << " over " << count << " frames, logdet is "
              << hlda_mat_full.LogDet();

    WriteKaldiObject(hlda_mat_full, hldafull_out_filename, binary);
    WriteKaldiObject(hlda_mat_part, hldapart_out_filename, binary);
    {
      Output ko(model_out_filename, binary);
      trans_model.Write(ko.Stream(), binary);
      am_gmm.Write(ko.Stream(), binary);
    }
    return 0;
  } catch(const std::exception &e) {
    std::cerr << e.what() << '\n';
    return -1;
  }
}
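The loop over arguments 6 onward calls HldaAccsDiagGmm::Read() with add = true, so statistics from several files (presumably one per parallel job, as the names 1.0.hacc and 1.1.hacc suggest) are summed into a single accumulator. A minimal sketch of that read-and-add pattern follows; ToyAccs and its text format "<count> <dim> <v_1> ... <v_dim>" are hypothetical and do not reflect the real HLDA statistics layout.

```cpp
#include <cassert>
#include <cstddef>
#include <istream>
#include <sstream>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical stand-in for an accumulator such as HldaAccsDiagGmm.
struct ToyAccs {
  double count = 0.0;
  std::vector<double> sums;

  // If add is true and stats are already present, the stats read from `is`
  // are summed into the existing ones; otherwise they replace them.
  void Read(std::istream &is, bool add) {
    double c;
    std::size_t dim;
    if (!(is >> c >> dim)) throw std::runtime_error("bad header");
    std::vector<double> v(dim);
    for (double &x : v)
      if (!(is >> x)) throw std::runtime_error("bad data");
    if (add && !sums.empty()) {
      if (v.size() != sums.size()) throw std::runtime_error("dim mismatch");
      count += c;
      for (std::size_t i = 0; i < v.size(); ++i) sums[i] += v[i];
    } else {
      count = c;
      sums = std::move(v);
    }
  }
};
```

Reading into an empty accumulator with add = true behaves like adding to zero, which is why main() can pass add = true even for the first file.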
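main() checks that the full HLDA matrix is FeatureDim x FeatureDim and allocates the partial matrix as ModelDim x FeatureDim. In standard HLDA, the first ModelDim rows of the full transform project onto the retained (useful) dimensions and the remaining rows onto nuisance dimensions. Assuming Update() follows that convention (hlda.cc is the place to confirm it), the partial matrix is simply the leading rows of the full one. The sketch uses plain std::vector instead of Kaldi's Matrix; Mat and LeadingRows are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major dense matrix used only for this illustration.
using Mat = std::vector<std::vector<double>>;

// Hypothetical helper: returns the first model_dim rows of the square
// feature_dim x feature_dim transform `full` (assumed HLDA convention).
Mat LeadingRows(const Mat &full, std::size_t model_dim) {
  const std::size_t feature_dim = full.size();
  for (const auto &row : full)
    assert(row.size() == feature_dim);  // square, like the KALDI_ASSERT in main()
  assert(model_dim <= feature_dim);
  return Mat(full.begin(),
             full.begin() + static_cast<std::ptrdiff_t>(model_dim));
}
```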