nnet-pdf-prior.cc
Go to the documentation of this file.
1 // nnet/nnet-pdf-prior.cc
2 
3 // Copyright 2013 Brno University of Technology (Author: Karel Vesely);
4 // Arnab Ghoshal
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABILITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include "nnet/nnet-pdf-prior.h"
22 
23 namespace kaldi {
24 namespace nnet1 {
25 
27  : prior_scale_(opts.prior_scale) {
28  if (opts.class_frame_counts == "") {
29  // class_frame_counts is empty, the PdfPrior is deactivated...
30  // (for example when 'nnet-forward' generates bottleneck features)
31  return;
32  }
33 
34  KALDI_LOG << "Computing pdf-priors from : " << opts.class_frame_counts;
35 
36  Vector<double> frame_counts, rel_freq, log_priors;
37  {
38  Input in;
40  frame_counts.Read(in.Stream(), false);
41  in.Close();
42  }
43 
44  // get relative frequencies,
45  rel_freq = frame_counts;
46  rel_freq.Scale(1.0/frame_counts.Sum());
47 
48  // get the log-prior,
49  log_priors = rel_freq;
50  log_priors.Add(1e-20);
51  log_priors.ApplyLog();
52 
53  // Make the priors for classes with low counts +inf (i.e. -log(0))
54  // such that the classes have 0 likelihood (i.e. -inf log-likelihood).
55  // We use sqrt(FLT_MAX) instead of -kLogZeroFloat to prevent NANs
56  // from appearing in computation.
57  int32 num_floored = 0;
58  for (int32 i = 0; i < log_priors.Dim(); i++) {
59  if (rel_freq(i) < opts.prior_floor) {
60  log_priors(i) = sqrt(FLT_MAX);
61  num_floored++;
62  }
63  }
64  KALDI_LOG << "Floored " << num_floored << " pdf-priors "
65  << "(hard-set to " << sqrt(FLT_MAX)
66  << ", which disables DNN output when decoding)";
67 
68  // sanity check,
69  KALDI_ASSERT(KALDI_ISFINITE(log_priors.Sum()));
70 
71  // push to GPU,
72  log_priors_ = Vector<BaseFloat>(log_priors);
73 }
74 
75 
77  if (log_priors_.Dim() == 0) {
78  KALDI_ERR << "--class-frame-counts is empty: Cannot initialize priors "
79  << "without the counts.";
80  }
81  if (log_priors_.Dim() != llk->NumCols()) {
82  KALDI_ERR << "Dimensionality mismatch,"
83  << " class_frame_counts " << log_priors_.Dim()
84  << " pdf_output_llk " << llk->NumCols();
85  }
87 }
88 
89 } // namespace nnet1
90 } // namespace kaldi
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
CuVector< BaseFloat > log_priors_
#define KALDI_ISFINITE(x)
Definition: kaldi-math.h:74
kaldi::int32 int32
bool OpenTextMode(const std::string &rxfilename)
Definition: kaldi-io-inl.h:30
void ApplyLog()
Apply natural log to all elements.
void SubtractOnLogpost(CuMatrixBase< BaseFloat > *llk)
Subtract pdf priors from log-posteriors to get pseudo log-likelihoods.
std::istream & Stream()
Definition: kaldi-io.cc:826
PdfPrior(const PdfPriorOptions &opts)
Initialize pdf-prior from options.
void AddVecToRows(Real alpha, const CuVectorBase< Real > &row, Real beta=1.0)
(for each row r of *this), r = alpha * row + beta * r
Definition: cu-matrix.cc:1261
#define KALDI_ERR
Definition: kaldi-error.h:147
int32 Close()
Definition: kaldi-io.cc:761
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
void Scale(Real alpha)
Multiplies all elements by this constant.
Real Sum() const
Returns sum of the elements.
Matrix for CUDA computing.
Definition: matrix-common.h:69
MatrixIndexT NumCols() const
Definition: cu-matrix.h:216
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void Add(Real c)
Add a constant to each element of a vector.
#define KALDI_LOG
Definition: kaldi-error.h:153
void Read(std::istream &in, bool binary, bool add=false)
Read function using C++ streams.