PdfPrior Class Reference

#include <nnet-pdf-prior.h>

Public Member Functions

 PdfPrior (const PdfPriorOptions &opts)
 Initialize pdf-prior from options.
 
void SubtractOnLogpost (CuMatrixBase< BaseFloat > *llk)
 Subtract pdf priors from log-posteriors to get pseudo log-likelihoods.
 

Private Member Functions

 KALDI_DISALLOW_COPY_AND_ASSIGN (PdfPrior)
 

Private Attributes

BaseFloat prior_scale_
 
CuVector< BaseFloat > log_priors_
 

Detailed Description

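PdfPrior converts DNN log-posteriors into pseudo log-likelihoods for decoding. It reads per-pdf frame counts from PdfPriorOptions::class_frame_counts, turns them into log-priors, and subtracts those log-priors, scaled by prior_scale_, from every row of the network output.

Definition at line 59 of file nnet-pdf-prior.h.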

Constructor & Destructor Documentation

PdfPrior()

PdfPrior ( const PdfPriorOptions & opts )
explicit

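Initializes the pdf-priors from options. If PdfPriorOptions::class_frame_counts is empty, the object is deactivated and log_priors_ stays empty (for example when nnet-forward generates bottleneck features). Otherwise the frame counts are read, normalized to relative frequencies and converted to log-priors; pdfs whose relative frequency is below PdfPriorOptions::prior_floor are hard-set to sqrt(FLT_MAX), which disables them during decoding.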
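Reading the listing below, the constructor's computation can be summarized as follows. The symbols are ours, not Kaldi's: $c_i$ is the count of pdf $i$ read from class_frame_counts, $r_i$ its relative frequency, $\phi$ is prior_floor, and $\log \pi_i$ is the stored log-prior.

```latex
r_i = \frac{c_i}{\sum_j c_j}, \qquad
\log \pi_i =
\begin{cases}
  \log\left(r_i + 10^{-20}\right) & \text{if } r_i \ge \phi, \\
  \sqrt{\texttt{FLT\_MAX}} \approx 1.84 \times 10^{19} & \text{if } r_i < \phi.
\end{cases}
```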

Definition at line 26 of file nnet-pdf-prior.cc.

References VectorBase< Real >::Add(), VectorBase< Real >::ApplyLog(), PdfPriorOptions::class_frame_counts, Input::Close(), VectorBase< Real >::Dim(), KALDI_ASSERT, KALDI_ISFINITE, KALDI_LOG, PdfPrior::log_priors_, Input::OpenTextMode(), PdfPriorOptions::prior_floor, Vector< Real >::Read(), VectorBase< Real >::Scale(), Input::Stream(), and VectorBase< Real >::Sum().

    PdfPrior::PdfPrior(const PdfPriorOptions &opts)
        : prior_scale_(opts.prior_scale) {
      if (opts.class_frame_counts == "") {
        // class_frame_counts is empty, the PdfPrior is deactivated...
        // (for example when 'nnet-forward' generates bottleneck features)
        return;
      }

      KALDI_LOG << "Computing pdf-priors from : " << opts.class_frame_counts;

      Vector<double> frame_counts, rel_freq, log_priors;
      {
        Input in;
        in.OpenTextMode(opts.class_frame_counts);
        frame_counts.Read(in.Stream(), false);
        in.Close();
      }

      // get relative frequencies,
      rel_freq = frame_counts;
      rel_freq.Scale(1.0/frame_counts.Sum());

      // get the log-prior,
      log_priors = rel_freq;
      log_priors.Add(1e-20);
      log_priors.ApplyLog();

      // Make the priors for classes with low counts +inf (i.e. -log(0))
      // such that the classes have 0 likelihood (i.e. -inf log-likelihood).
      // We use sqrt(FLT_MAX) instead of -kLogZeroFloat to prevent NANs
      // from appearing in computation.
      int32 num_floored = 0;
      for (int32 i = 0; i < log_priors.Dim(); i++) {
        if (rel_freq(i) < opts.prior_floor) {
          log_priors(i) = sqrt(FLT_MAX);
          num_floored++;
        }
      }
      KALDI_LOG << "Floored " << num_floored << " pdf-priors "
                << "(hard-set to " << sqrt(FLT_MAX)
                << ", which disables DNN output when decoding)";

      // sanity check,
      KALDI_ASSERT(KALDI_ISFINITE(log_priors.Sum()));

      // push to GPU,
      log_priors_ = Vector<BaseFloat>(log_priors);
    }
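The same prior computation can be sketched outside Kaldi with the standard library only. ComputeLogPriors is a hypothetical helper, not part of Kaldi's API; it uses std::vector<double> in place of kaldi::Vector<double>, and the check for a non-positive total is our addition (the Kaldi code relies on KALDI_ASSERT catching non-finite results instead).

```cpp
#include <cassert>
#include <cfloat>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <stdexcept>
#include <vector>

// Hypothetical helper (not Kaldi API): mirrors the PdfPrior constructor
// using std::vector<double> instead of kaldi::Vector<double>.
std::vector<double> ComputeLogPriors(const std::vector<double>& frame_counts,
                                     double prior_floor, int* num_floored) {
  assert(num_floored != nullptr);
  const double total =
      std::accumulate(frame_counts.begin(), frame_counts.end(), 0.0);
  // Added check (assumption, not in Kaldi): all-zero counts would give NaNs.
  if (!(total > 0.0)) {
    throw std::invalid_argument("frame counts must sum to a positive value");
  }
  std::vector<double> log_priors(frame_counts.size());
  *num_floored = 0;
  for (std::size_t i = 0; i < frame_counts.size(); ++i) {
    const double rel_freq = frame_counts[i] / total;  // relative frequency
    if (rel_freq < prior_floor) {
      // Rare pdf: store a large but finite value instead of +inf (= -log 0),
      // so later arithmetic on it stays finite rather than producing NaNs.
      log_priors[i] = std::sqrt(FLT_MAX);
      ++*num_floored;
    } else {
      log_priors[i] = std::log(rel_freq + 1e-20);  // small offset, as in Kaldi
    }
  }
  return log_priors;
}
```

For counts {1, 3, 0} and a floor of 1e-10, the first two entries are log(0.25) and log(0.75), and the third pdf is floored to sqrt(FLT_MAX).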

Member Function Documentation

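KALDI_DISALLOW_COPY_AND_ASSIGN()

KALDI_DISALLOW_COPY_AND_ASSIGN ( PdfPrior )
private

Declares the copy constructor and assignment operator of PdfPrior as private, so PdfPrior objects cannot be copied or assigned.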
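In C++11 and later the same intent is usually written with deleted member functions. The sketch below uses a hypothetical stand-in class, NonCopyablePrior, and does not reproduce the macro's exact expansion; it only shows the resulting behavior, namely that copying fails at compile time.

```cpp
#include <type_traits>

// Hypothetical stand-in for PdfPrior: deleted copy operations have the same
// effect as a private, undefined copy constructor and assignment operator.
class NonCopyablePrior {
 public:
  explicit NonCopyablePrior(float prior_scale) : prior_scale_(prior_scale) {}
  NonCopyablePrior(const NonCopyablePrior&) = delete;             // no copy
  NonCopyablePrior& operator=(const NonCopyablePrior&) = delete;  // no assign
  float prior_scale() const { return prior_scale_; }

 private:
  float prior_scale_;
};

// Checked at compile time: any attempt to copy would not compile.
static_assert(!std::is_copy_constructible<NonCopyablePrior>::value,
              "copy construction is disabled");
static_assert(!std::is_copy_assignable<NonCopyablePrior>::value,
              "copy assignment is disabled");
```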

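SubtractOnLogpost()

void SubtractOnLogpost ( CuMatrixBase< BaseFloat > * llk)

Subtracts the pdf log-priors, scaled by prior_scale_, from every row of the log-posterior matrix llk in place, producing pseudo log-likelihoods. Raises KALDI_ERR if the priors were never computed (empty --class-frame-counts) or if the number of priors differs from llk->NumCols().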
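The reason this yields usable likelihoods follows from Bayes' rule. In our notation (not Kaldi's), $\mathbf{x}_t$ is the input at frame $t$, $s$ a pdf, $\pi_s$ its prior and $\alpha$ = prior_scale_. The term $\log p(\mathbf{x}_t)$ is the same for every pdf at frame $t$, so it adds the same constant to every decoding path through that frame and can be dropped; the result is therefore only a pseudo log-likelihood. The call AddVecToRows(-prior_scale_, log_priors_) computes the second line for every row $t$ and column $s$:

```latex
\log p(\mathbf{x}_t \mid s) = \log p(s \mid \mathbf{x}_t) + \log p(\mathbf{x}_t) - \log \pi_s
\\
\tilde{\ell}_{t,s} = \log p(s \mid \mathbf{x}_t) - \alpha \log \pi_s
```

For a floored pdf, $\log \pi_s = \sqrt{\texttt{FLT\_MAX}}$, so with $\alpha > 0$ its score drops by roughly $\alpha \cdot 1.84 \times 10^{19}$ and the pdf is effectively excluded.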

Definition at line 76 of file nnet-pdf-prior.cc.

References CuMatrixBase< Real >::AddVecToRows(), KALDI_ERR, PdfPrior::log_priors_, CuMatrixBase< Real >::NumCols(), and PdfPrior::prior_scale_.

Referenced by main().

    void PdfPrior::SubtractOnLogpost(CuMatrixBase<BaseFloat> *llk) {
      if (log_priors_.Dim() == 0) {
        KALDI_ERR << "--class-frame-counts is empty: Cannot initialize priors "
                  << "without the counts.";
      }
      if (log_priors_.Dim() != llk->NumCols()) {
        KALDI_ERR << "Dimensionality mismatch,"
                  << " class_frame_counts " << log_priors_.Dim()
                  << " pdf_output_llk " << llk->NumCols();
      }
      llk->AddVecToRows(-prior_scale_, log_priors_);
    }
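A CPU-only sketch of the same update on a row-major std::vector may help when reasoning about the arithmetic. SubtractPriorsRowMajor is hypothetical and not part of Kaldi; the real method works on a CuMatrixBase<BaseFloat> that may live on the GPU, and it reports errors through KALDI_ERR rather than C++ exceptions.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical CPU analogue of PdfPrior::SubtractOnLogpost (not Kaldi API).
// llk holds a row-major matrix with num_cols columns, one per pdf.
void SubtractPriorsRowMajor(const std::vector<float>& log_priors,
                            float prior_scale, std::size_t num_cols,
                            std::vector<float>* llk) {
  assert(llk != nullptr);
  if (log_priors.empty()) {
    throw std::runtime_error("priors were not initialized (no frame counts)");
  }
  // num_cols > 0 here, because it must equal the non-empty prior count.
  if (log_priors.size() != num_cols || llk->size() % num_cols != 0) {
    throw std::runtime_error("dimensionality mismatch");
  }
  const std::size_t num_rows = llk->size() / num_cols;
  for (std::size_t r = 0; r < num_rows; ++r) {
    for (std::size_t c = 0; c < num_cols; ++c) {
      // Same per-element update as llk->AddVecToRows(-prior_scale, log_priors).
      (*llk)[r * num_cols + c] -= prior_scale * log_priors[c];
    }
  }
}
```

With log-priors {-1, -2}, a scale of 0.5 and the 2x2 input {0, 0; 1, 1}, the result is {0.5, 1.0; 1.5, 2.0}.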

Member Data Documentation

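log_priors_

CuVector<BaseFloat> log_priors_
private

Per-pdf log-priors computed by the constructor and stored in a CuVector (on the GPU when one is in use). Stays empty when class_frame_counts is not given, in which case SubtractOnLogpost() raises an error.

Definition at line 69 of file nnet-pdf-prior.h.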

Referenced by PdfPrior::PdfPrior(), and PdfPrior::SubtractOnLogpost().

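prior_scale_

BaseFloat prior_scale_
private

Scale applied to the log-priors before they are subtracted; copied from PdfPriorOptions::prior_scale.

Definition at line 68 of file nnet-pdf-prior.h.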

Referenced by PdfPrior::SubtractOnLogpost().


The documentation for this class was generated from the following files:

nnet-pdf-prior.h
nnet-pdf-prior.cc