nnet3-am-adjust-priors.cc File Reference
Include dependency graph for nnet3-am-adjust-priors.cc:

Go to the source code of this file.

Namespaces

 kaldi
 This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for mispronunciation detection tasks; the reference:
 
 kaldi::nnet3
 

Functions

BaseFloat KlDivergence (const Vector< BaseFloat > &p, const Vector< BaseFloat > &q)
 
void PrintPriorDiagnostics (const Vector< BaseFloat > &old_priors, const Vector< BaseFloat > &new_priors)
 
int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 70 of file nnet3-am-adjust-priors.cc.

References ParseOptions::GetArg(), KALDI_ASSERT, KALDI_LOG, ParseOptions::NumArgs(), kaldi::nnet3::PrintPriorDiagnostics(), ParseOptions::PrintUsage(), AmNnetSimple::Priors(), AmNnetSimple::Read(), ParseOptions::Read(), TransitionModel::Read(), kaldi::ReadKaldiObject(), ParseOptions::Register(), VectorBase< Real >::Scale(), AmNnetSimple::SetPriors(), Output::Stream(), Input::Stream(), VectorBase< Real >::Sum(), AmNnetSimple::Write(), and TransitionModel::Write().

70  { // Entry point: reads a model and a summed-posterior vector, replaces the model's priors with the renormalized vector, and writes the model out.
71  try {
72  using namespace kaldi;
73  using namespace kaldi::nnet3;
74  typedef kaldi::int32 int32;
75 
76  const char *usage = // Help text handed to ParseOptions below.
77  "Set the priors of the nnet3 neural net to the computed posterios from the net,\n" // NOTE(review): "posterios" looks like a typo for "posteriors" in this user-facing text.
78  "on typical data (e.g. training data). This is correct under more general\n"
79  "circumstances than using the priors of the class labels in the training data\n"
80  "\n"
81  "Typical usage of this program will involve computation of an average pdf-level\n"
82  "posterior with nnet3-compute or nnet3-compute-from-egs, piped into matrix-sum-rows\n"
83  "and then vector-sum, to compute the average posterior\n"
84  "\n"
85  "Usage: nnet3-am-adjust-priors [options] <nnet-in> <summed-posterior-vector-in> <nnet-out>\n"
86  "e.g.:\n"
87  " nnet3-am-adjust-priors final.mdl counts.vec final.mdl\n";
88 
89  bool binary_write = true; // Default for the --binary option registered below.
90  BaseFloat prior_floor = 1.0e-15; // Have a very low prior floor, since this method
91  // isn't likely to have a problem with very improbable
92  // classes. NOTE(review): prior_floor is registered as --prior-floor below but is not referenced anywhere else in this function, so the option currently has no effect; confirm whether a floor should be applied before SetPriors().
93 
94  ParseOptions po(usage);
95  po.Register("binary", &binary_write, "Write output in binary mode");
96  po.Register("prior-floor", &prior_floor, "When setting priors, floor for "
97  "priors (only used to avoid generating NaNs upon inversion)");
98 
99  po.Read(argc, argv); // Parse the command line into the registered options and positional arguments.
100 
101  if (po.NumArgs() != 3) { // Exactly three positional arguments are required: model in, posterior vector in, model out.
102  po.PrintUsage();
103  exit(1); // Usage error: exit with status 1 (distinct from the -1 returned by the catch block).
104  }
105 
106  std::string nnet_rxfilename = po.GetArg(1), // Input model.
107  posterior_vec_rxfilename = po.GetArg(2), // Summed posterior vector.
108  nnet_wxfilename = po.GetArg(3); // Output model.
109 
110  TransitionModel trans_model;
111  AmNnetSimple am_nnet;
112  { // Inner scope: ki goes out of scope as soon as both objects have been read.
113  bool binary_read; // Left uninitialized here; its address is handed to Input on the next line.
114  Input ki(nnet_rxfilename, &binary_read); // presumably Input sets binary_read to the detected file mode; confirm against the Input class.
115  trans_model.Read(ki.Stream(), binary_read); // The transition model is read first ...
116  am_nnet.Read(ki.Stream(), binary_read); // ... then the acoustic model, from the same stream.
117  }
118 
119 
120  Vector<BaseFloat> posterior_vec;
121  ReadKaldiObject(posterior_vec_rxfilename, &posterior_vec);
122 
123  KALDI_ASSERT(posterior_vec.Sum() > 0.0); // Guards the division on the next line.
124  posterior_vec.Scale(1.0 / posterior_vec.Sum()); // Renormalize: divide every element by the current sum.
125 
126  Vector<BaseFloat> old_priors(am_nnet.Priors()); // Copy of the model's current priors, taken before they are overwritten.
127 
128  PrintPriorDiagnostics(old_priors, posterior_vec); // Called with (old priors, new priors).
129 
130  am_nnet.SetPriors(posterior_vec); // The renormalized vector becomes the model's priors; no floor is applied here.
131 
132  { // Inner scope: ko goes out of scope before the log message below.
133  Output ko(nnet_wxfilename, binary_write);
134  trans_model.Write(ko.Stream(), binary_write); // Written in the same order as it was read: transition model first,
135  am_nnet.Write(ko.Stream(), binary_write); // then the acoustic model.
136  }
137  KALDI_LOG << "Modified priors of neural network model and wrote it to "
138  << nnet_wxfilename;
139  return 0;
140  } catch(const std::exception &e) { // Any std::exception from the block above ends up here.
141  std::cerr << e.what() << '\n';
142  return -1; // Failure status after printing the exception message to stderr.
143  }
144 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
kaldi::int32 int32
void Read(std::istream &is, bool binary)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
void SetPriors(const VectorBase< BaseFloat > &priors)
void PrintPriorDiagnostics(const Vector< BaseFloat > &old_priors, const Vector< BaseFloat > &new_priors)
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Read(std::istream &is, bool binary)
void Scale(Real alpha)
Multiplies all elements by this constant.
Real Sum() const
Returns sum of the elements.
const VectorBase< BaseFloat > & Priors() const
void Write(std::ostream &os, bool binary) const
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
#define KALDI_LOG
Definition: kaldi-error.h:153
void Write(std::ostream &os, bool binary) const