nnet-adjust-priors.cc File Reference
Include dependency graph for nnet-adjust-priors.cc:

Go to the source code of this file.

Namespaces

 kaldi
 This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for mispronunciation detection tasks; see the reference:
 
 kaldi::nnet2
 

Functions

BaseFloat KlDivergence (const Vector< BaseFloat > &p, const Vector< BaseFloat > &q)
 
void PrintPriorDiagnostics (const Vector< BaseFloat > &old_priors, const Vector< BaseFloat > &new_priors)
 
int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 70 of file nnet-adjust-priors.cc.

References ParseOptions::GetArg(), KALDI_ASSERT, KALDI_LOG, ParseOptions::NumArgs(), kaldi::nnet2::PrintPriorDiagnostics(), ParseOptions::PrintUsage(), AmNnet::Priors(), AmNnet::Read(), ParseOptions::Read(), TransitionModel::Read(), kaldi::ReadKaldiObject(), ParseOptions::Register(), VectorBase< Real >::Scale(), AmNnet::SetPriors(), Output::Stream(), Input::Stream(), VectorBase< Real >::Sum(), AmNnet::Write(), and TransitionModel::Write().

70  {
71  try {
72  using namespace kaldi;
73  using namespace kaldi::nnet2;
74  typedef kaldi::int32 int32;
75 
76  const char *usage =
77  "Set the priors of the neural net to the computed posterios from the net,\n"
78  "on typical data (e.g. training data). This is correct under more general\n"
79  "circumstances than using the priors of the class labels in the training data\n"
80  "\n"
81  "Typical usage of this program will involve computation of an average pdf-level\n"
82  "posterior with nnet-compute or nnet-compute-from-egs, piped into matrix-sum-rows\n"
83  "and then vector-sum, to compute the average posterior\n"
84  "\n"
85  "Usage: nnet-adjust-priors [options] <nnet-in> <summed-posterior-vector-in> <nnet-out>\n"
86  "e.g.:\n"
87  " nnet-adjust-priors final.mdl prior.vec final.mdl\n";
88 
89  bool binary_write = true;
90  BaseFloat prior_floor = 1.0e-15; // Have a very low prior floor, since this method
91  // isn't likely to have a problem with very improbable
92  // classes.
93 
94  ParseOptions po(usage);
95  po.Register("binary", &binary_write, "Write output in binary mode");
96  po.Register("prior-floor", &prior_floor, "When setting priors, floor for "
97  "priors (only used to avoid generating NaNs upon inversion)");
98 
99  po.Read(argc, argv);
100 
101  if (po.NumArgs() != 3) {
102  po.PrintUsage();
103  exit(1);
104  }
105 
106  std::string nnet_rxfilename = po.GetArg(1),
107  posterior_vec_rxfilename = po.GetArg(2),
108  nnet_wxfilename = po.GetArg(3);
109 
110  TransitionModel trans_model;
111  AmNnet am_nnet;
112  {
113  bool binary_read;
114  Input ki(nnet_rxfilename, &binary_read);
115  trans_model.Read(ki.Stream(), binary_read);
116  am_nnet.Read(ki.Stream(), binary_read);
117  }
118 
119 
120  Vector<BaseFloat> posterior_vec;
121  ReadKaldiObject(posterior_vec_rxfilename, &posterior_vec);
122 
123  KALDI_ASSERT(posterior_vec.Sum() > 0.0);
124  posterior_vec.Scale(1.0 / posterior_vec.Sum()); // Renormalize
125 
126  Vector<BaseFloat> old_priors(am_nnet.Priors());
127 
128  PrintPriorDiagnostics(old_priors, posterior_vec);
129 
130  am_nnet.SetPriors(posterior_vec);
131 
132  {
133  Output ko(nnet_wxfilename, binary_write);
134  trans_model.Write(ko.Stream(), binary_write);
135  am_nnet.Write(ko.Stream(), binary_write);
136  }
137  KALDI_LOG << "Modified priors of neural network model and wrote it to "
138  << nnet_wxfilename;
139  return 0;
140  } catch(const std::exception &e) {
141  std::cerr << e.what() << '\n';
142  return -1;
143  }
144 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
Definition: chain.dox:20
void Read(std::istream &is, bool binary)
Definition: am-nnet.cc:39
kaldi::int32 int32
void PrintPriorDiagnostics(const Vector< BaseFloat > &old_priors, const Vector< BaseFloat > &new_priors)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Read(std::istream &is, bool binary)
const VectorBase< BaseFloat > & Priors() const
Definition: am-nnet.h:67
void Write(std::ostream &os, bool binary) const
Definition: am-nnet.cc:31
void Scale(Real alpha)
Multiplies all elements by this constant.
Real Sum() const
Returns sum of the elements.
void Write(std::ostream &os, bool binary) const
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void SetPriors(const VectorBase< BaseFloat > &priors)
Definition: am-nnet.cc:44
#define KALDI_LOG
Definition: kaldi-error.h:153