sgmm2-init.cc File Reference
Include dependency graph for sgmm2-init.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 28 of file sgmm2-init.cc.

References AmSgmm2::ComputeNormalizers(), AmSgmm2::CopyGlobalsInitVecs(), ParseOptions::GetArg(), rnnlm::i, AmSgmm2::InitializeFromFullGmm(), KALDI_LOG, kaldi::kSgmmWriteAll, ParseOptions::NumArgs(), ContextDependency::NumPdfs(), ParseOptions::PrintUsage(), ContextDependency::Read(), ParseOptions::Read(), FullGmm::Read(), AmSgmm2::Read(), kaldi::ReadIntegerVector(), kaldi::ReadKaldiObject(), ParseOptions::Register(), Output::Stream(), Input::Stream(), TransitionModel::Write(), and AmSgmm2::Write().

28  {
29  try {
30  using namespace kaldi;
31  typedef kaldi::int32 int32;
32 
33  const char *usage =
34  "Initialize an SGMM from a trained full-covariance UBM and a specified"
35  " model topology.\n"
36  "Usage: sgmm2-init [options] <topology> <tree> <init-model> <sgmm-out>\n"
37  "The <init-model> argument can be a UBM (the default case) or another\n"
38  "SGMM (if the --init-from-sgmm flag is used).\n"
39  "For systems with two-level tree, use --pdf-map argument.";
40 
41  bool binary = true, init_from_sgmm = false, spk_dep_weights = false; // will
42  // make it true later.
43  int32 phn_space_dim = 0, spk_space_dim = 0;
44  std::string pdf_map_rxfilename;
45  double self_weight = 1.0;
46 
47  kaldi::ParseOptions po(usage);
48  po.Register("binary", &binary, "Write output in binary mode");
49  po.Register("phn-space-dim", &phn_space_dim, "Phonetic space dimension.");
50  po.Register("spk-space-dim", &spk_space_dim, "Speaker space dimension.");
51  po.Register("spk-dep-weights", &spk_dep_weights, "If true, have speaker-"
52  "dependent weights (symmetric SGMM)");
53  po.Register("init-from-sgmm", &init_from_sgmm,
54  "Initialize from another SGMM (instead of a UBM).");
55  po.Register("self-weight", &self_weight,
56  "If < 1.0, will be the weight of a pdf with its \"own\" mixture, "
57  "where we initialize each group with a number of mixtures. If"
58  "1.0, we initialize each group with just one mixture component.");
59  po.Register("pdf-map", &pdf_map_rxfilename,
60  "For systems with 2-level trees [SCTM systems], the file that "
61  "maps from pdfs to groups (from build-tree-two-level)");
62 
63  po.Read(argc, argv);
64 
65  if (po.NumArgs() != 4) {
66  po.PrintUsage();
67  exit(1);
68  }
69 
70  std::string topo_in_filename = po.GetArg(1),
71  tree_in_filename = po.GetArg(2),
72  init_model_filename = po.GetArg(3),
73  sgmm_out_filename = po.GetArg(4);
74 
75  ContextDependency ctx_dep;
76  {
77  bool binary_in;
78  Input ki(tree_in_filename.c_str(), &binary_in);
79  ctx_dep.Read(ki.Stream(), binary_in);
80  }
81 
82  std::vector<int32> pdf2group;
83  if (pdf_map_rxfilename != "") {
84  bool binary_in;
85  Input ki(pdf_map_rxfilename, &binary_in);
86  ReadIntegerVector(ki.Stream(), binary_in, &pdf2group);
87  } else {
88  for (int32 i = 0; i < ctx_dep.NumPdfs(); i++) pdf2group.push_back(i);
89  }
90 
91 
92  HmmTopology topo;
93  ReadKaldiObject(topo_in_filename, &topo);
94 
95  TransitionModel trans_model(ctx_dep, topo);
96 
97  kaldi::AmSgmm2 sgmm;
98  if (init_from_sgmm) {
99  kaldi::AmSgmm2 init_sgmm;
100  {
101  bool binary_read;
102  kaldi::Input ki(init_model_filename, &binary_read);
103  init_sgmm.Read(ki.Stream(), binary_read);
104  }
105  sgmm.CopyGlobalsInitVecs(init_sgmm, pdf2group, self_weight);
106  } else {
107  kaldi::FullGmm ubm;
108  {
109  bool binary_read;
110  kaldi::Input ki(init_model_filename, &binary_read);
111  ubm.Read(ki.Stream(), binary_read);
112  }
113  sgmm.InitializeFromFullGmm(ubm, pdf2group, phn_space_dim,
114  spk_space_dim, spk_dep_weights,
115  self_weight);
116  }
117  sgmm.ComputeNormalizers();
118 
119  {
120  kaldi::Output ko(sgmm_out_filename, binary);
121  trans_model.Write(ko.Stream(), binary);
122  sgmm.Write(ko.Stream(), binary, kaldi::kSgmmWriteAll);
123  }
124 
125  KALDI_LOG << "Written model to " << sgmm_out_filename;
126  } catch(const std::exception &e) {
127  std::cerr << e.what() << '\n';
128  return -1;
129  }
130 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void Write(std::ostream &os, bool binary, SgmmWriteFlagsType write_params) const
Definition: am-sgmm2.cc:203
Class for definition of the subspace Gmm acoustic model.
Definition: am-sgmm2.h:231
A class for storing topology information for phones.
Definition: hmm-topology.h:93
Definition for Gaussian Mixture Model with full covariances.
Definition: full-gmm.h:40
void Read(std::istream &is, bool binary)
Definition: am-sgmm2.cc:89
void InitializeFromFullGmm(const FullGmm &gmm, const std::vector< int32 > &pdf2group, int32 phn_subspace_dim, int32 spk_subspace_dim, bool speaker_dependent_weights, BaseFloat self_weight)
Initializes the SGMM parameters from a full-covariance UBM.
Definition: am-sgmm2.cc:381
kaldi::int32 int32
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
virtual int32 NumPdfs() const
NumPdfs() returns the number of acoustic pdfs (they are numbered 0.. NumPdfs()-1).
Definition: context-dep.h:71
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void ReadIntegerVector(std::istream &is, bool binary, std::vector< T > *v)
Function for reading STL vector of integer types.
Definition: io-funcs-inl.h:232
void CopyGlobalsInitVecs(const AmSgmm2 &other, const std::vector< int32 > &pdf2group, BaseFloat self_weight)
Copies the global parameters from the supplied model, but sets the state vectors to zero...
Definition: am-sgmm2.cc:1183
void Read(std::istream &is, bool binary)
Definition: full-gmm.cc:813
void Read(std::istream &is, bool binary)
Read context-dependency object from disk; throws on error.
Definition: context-dep.cc:155
void ComputeNormalizers()
Computes the data-independent terms in the log-likelihood computation for each Gaussian component and...
Definition: am-sgmm2.cc:857
#define KALDI_LOG
Definition: kaldi-error.h:153