nnet3-am-adjust-priors.cc
Go to the documentation of this file.
1 // nnet3bin/nnet3-am-adjust-priors.cc
2 
3 // Copyright 2014 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include "base/kaldi-common.h"
21 #include "util/common-utils.h"
22 #include "nnet3/am-nnet-simple.h"
23 #include "hmm/transition-model.h"
24 #include "tree/context-dep.h"
25 
26 namespace kaldi {
27 namespace nnet3 {
28 
29 
30 // Computes one-sided K-L divergence from p to q.
32  const Vector<BaseFloat> &q) {
33  BaseFloat sum_p = p.Sum(), sum_q = q.Sum();
34  if (fabs(sum_p - 1.0) > 0.01 || fabs(sum_q - 1.0) > 0.01) {
35  KALDI_WARN << "KlDivergence: vectors are not close to being normalized "
36  << sum_p << ", " << sum_q;
37  }
38  KALDI_ASSERT(p.Dim() == q.Dim());
39  double ans = 0.0;
40 
41  for (int32 i = 0; i < p.Dim(); i++) {
42  BaseFloat p_prob = p(i) / sum_p, q_prob = q(i) / sum_q;
43  ans += p_prob * Log(p_prob / q_prob);
44  }
45  return ans;
46 }
47 
49  const Vector<BaseFloat> &new_priors) {
50  if (old_priors.Dim() == 0) {
51  KALDI_LOG << "Model did not previously have priors attached.";
52  } else {
53  Vector<BaseFloat> diff_prior(new_priors);
54  diff_prior.AddVec(-1.0, old_priors);
55  diff_prior.ApplyAbs();
56  int32 max_index;
57  diff_prior.Max(&max_index);
58  KALDI_LOG << "Adjusting priors: largest absolute difference was for "
59  << "pdf " << max_index << ", " << old_priors(max_index)
60  << " -> " << new_priors(max_index);
61  KALDI_LOG << "Adjusting priors: K-L divergence from old to new is "
62  << KlDivergence(old_priors, new_priors);
63  }
64 }
65 
66 
67 } // namespace nnet3
68 } // namespace kaldi
69 
70 int main(int argc, char *argv[]) {
71  try {
72  using namespace kaldi;
73  using namespace kaldi::nnet3;
74  typedef kaldi::int32 int32;
75 
76  const char *usage =
77  "Set the priors of the nnet3 neural net to the computed posterios from the net,\n"
78  "on typical data (e.g. training data). This is correct under more general\n"
79  "circumstances than using the priors of the class labels in the training data\n"
80  "\n"
81  "Typical usage of this program will involve computation of an average pdf-level\n"
82  "posterior with nnet3-compute or nnet3-compute-from-egs, piped into matrix-sum-rows\n"
83  "and then vector-sum, to compute the average posterior\n"
84  "\n"
85  "Usage: nnet3-am-adjust-priors [options] <nnet-in> <summed-posterior-vector-in> <nnet-out>\n"
86  "e.g.:\n"
87  " nnet3-am-adjust-priors final.mdl counts.vec final.mdl\n";
88 
89  bool binary_write = true;
90  BaseFloat prior_floor = 1.0e-15; // Have a very low prior floor, since this method
91  // isn't likely to have a problem with very improbable
92  // classes.
93 
94  ParseOptions po(usage);
95  po.Register("binary", &binary_write, "Write output in binary mode");
96  po.Register("prior-floor", &prior_floor, "When setting priors, floor for "
97  "priors (only used to avoid generating NaNs upon inversion)");
98 
99  po.Read(argc, argv);
100 
101  if (po.NumArgs() != 3) {
102  po.PrintUsage();
103  exit(1);
104  }
105 
106  std::string nnet_rxfilename = po.GetArg(1),
107  posterior_vec_rxfilename = po.GetArg(2),
108  nnet_wxfilename = po.GetArg(3);
109 
110  TransitionModel trans_model;
111  AmNnetSimple am_nnet;
112  {
113  bool binary_read;
114  Input ki(nnet_rxfilename, &binary_read);
115  trans_model.Read(ki.Stream(), binary_read);
116  am_nnet.Read(ki.Stream(), binary_read);
117  }
118 
119 
120  Vector<BaseFloat> posterior_vec;
121  ReadKaldiObject(posterior_vec_rxfilename, &posterior_vec);
122 
123  KALDI_ASSERT(posterior_vec.Sum() > 0.0);
124  posterior_vec.Scale(1.0 / posterior_vec.Sum()); // Renormalize
125 
126  Vector<BaseFloat> old_priors(am_nnet.Priors());
127 
128  PrintPriorDiagnostics(old_priors, posterior_vec);
129 
130  am_nnet.SetPriors(posterior_vec);
131 
132  {
133  Output ko(nnet_wxfilename, binary_write);
134  trans_model.Write(ko.Stream(), binary_write);
135  am_nnet.Write(ko.Stream(), binary_write);
136  }
137  KALDI_LOG << "Modified priors of neural network model and wrote it to "
138  << nnet_wxfilename;
139  return 0;
140  } catch(const std::exception &e) {
141  std::cerr << e.what() << '\n';
142  return -1;
143  }
144 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
int main(int argc, char *argv[])
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
kaldi::int32 int32
void Read(std::istream &is, bool binary)
void Register(const std::string &name, bool *ptr, const std::string &doc)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
void SetPriors(const VectorBase< BaseFloat > &priors)
void PrintPriorDiagnostics(const Vector< BaseFloat > &old_priors, const Vector< BaseFloat > &new_priors)
std::istream & Stream()
Definition: kaldi-io.cc:826
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
std::ostream & Stream()
Definition: kaldi-io.cc:701
double Log(double x)
Definition: kaldi-math.h:100
void Read(std::istream &is, bool binary)
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
Real Max() const
Returns the maximum value of any element, or -infinity for the empty vector.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
#define KALDI_WARN
Definition: kaldi-error.h:150
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
void Scale(Real alpha)
Multiplies all elements by this constant.
Real Sum() const
Returns sum of the elements.
const VectorBase< BaseFloat > & Priors() const
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void Write(std::ostream &os, bool binary) const
BaseFloat KlDivergence(const Vector< BaseFloat > &p, const Vector< BaseFloat > &q)
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void ApplyAbs()
Take absolute value of each of the elements.
#define KALDI_LOG
Definition: kaldi-error.h:153
void AddVec(const Real alpha, const VectorBase< OtherReal > &v)
Add vector : *this = *this + alpha * rv (with casting between floats and doubles) ...
void Write(std::ostream &os, bool binary) const