DecodableAmNnet Class Reference

DecodableAmNnet is a decodable object that decodes with a neural net acoustic model of type AmNnet.

#include <decodable-am-nnet.h>

Inherits DecodableInterface.

Public Member Functions

 DecodableAmNnet (const TransitionModel &trans_model, const AmNnet &am_nnet, const CuMatrixBase< BaseFloat > &feats, bool pad_input=true, BaseFloat prob_scale=1.0)
 
virtual BaseFloat LogLikelihood (int32 frame, int32 transition_id)
 Returns the log likelihood, which will be negated in the decoder.
 
virtual int32 NumFramesReady () const
 Returns the number of frames currently available for this decodable object.
 
virtual int32 NumIndices () const
 Returns the number of states in the acoustic model (they will be indexed one-based, i.e. from 1 to NumIndices(); this is for compatibility with OpenFst).
 
virtual bool IsLastFrame (int32 frame) const
 Returns true if this is the last frame.
 
- Public Member Functions inherited from DecodableInterface
virtual ~DecodableInterface ()
 

Protected Member Functions

 KALDI_DISALLOW_COPY_AND_ASSIGN (DecodableAmNnet)
 

Protected Attributes

const TransitionModel & trans_model_
 
Matrix< BaseFloat > log_probs_
 

Detailed Description

DecodableAmNnet is a decodable object that decodes with a neural net acoustic model of type AmNnet.

Definition at line 37 of file decodable-am-nnet.h.

Constructor & Destructor Documentation

◆ DecodableAmNnet()

DecodableAmNnet ( const TransitionModel &  trans_model,
const AmNnet &  am_nnet,
const CuMatrixBase< BaseFloat > &  feats,
bool  pad_input = true,
BaseFloat  prob_scale = 1.0 
)
inline

Definition at line 39 of file decodable-am-nnet.h.

References AmNnet::GetNnet(), KALDI_ASSERT, KALDI_WARN, Nnet::LeftContext(), DecodableAmNnet::log_probs_, kaldi::nnet2::NnetComputation(), TransitionModel::NumPdfs(), CuMatrixBase< Real >::NumRows(), AmNnet::Priors(), Nnet::RightContext(), and Matrix< Real >::Swap().

DecodableAmNnet(const TransitionModel &trans_model,
                const AmNnet &am_nnet,
                const CuMatrixBase<BaseFloat> &feats,
                bool pad_input = true,
                BaseFloat prob_scale = 1.0):
    trans_model_(trans_model) {
  // Note: we could make this more memory-efficient by doing the
  // computation in smaller chunks than the whole utterance, and not
  // storing the whole thing. We'll leave this for later.
  int32 num_rows = feats.NumRows() -
      (pad_input ? 0 : am_nnet.GetNnet().LeftContext() +
                       am_nnet.GetNnet().RightContext());
  if (num_rows <= 0) {
    KALDI_WARN << "Input with " << feats.NumRows() << " rows will produce "
               << "empty output.";
    return;
  }
  CuMatrix<BaseFloat> log_probs(num_rows, trans_model.NumPdfs());
  // the following function is declared in nnet-compute.h
  NnetComputation(am_nnet.GetNnet(), feats, pad_input, &log_probs);
  log_probs.ApplyFloor(1.0e-20);  // Avoid log of zero which leads to NaN.
  log_probs.ApplyLog();
  CuVector<BaseFloat> priors(am_nnet.Priors());
  KALDI_ASSERT(priors.Dim() == trans_model.NumPdfs() &&
               "Priors in neural network not set up.");
  priors.ApplyLog();
  // subtract log-prior (divide by prior)
  log_probs.AddVecToRows(-1.0, priors);
  // apply probability scale.
  log_probs.Scale(prob_scale);
  // Transfer the log-probs to the CPU for faster access by the
  // decoding process.
  log_probs_.Swap(&log_probs);
}
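The arithmetic performed by the constructor can be illustrated without Kaldi. The following is a minimal sketch in plain C++, assuming std::vector stand-ins for the matrix and vector classes; Row, Mat, NumOutputRows and PosteriorsToScaledLogLikes are hypothetical names introduced here for illustration and are not part of Kaldi. It shows how the output row count depends on pad_input, and how each posterior is floored, logged, divided by the prior (in the log domain) and scaled by prob_scale.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Plain-C++ stand-ins for the Kaldi matrix types (hypothetical names).
using Row = std::vector<float>;
using Mat = std::vector<Row>;

// Number of output frames. With padding there is one output per input frame;
// without it, left_context + right_context frames are consumed at the edges.
// A result <= 0 corresponds to the "empty output" warning in the constructor.
int NumOutputRows(int num_input_rows, bool pad_input,
                  int left_context, int right_context) {
  return num_input_rows - (pad_input ? 0 : left_context + right_context);
}

// Converts per-frame pdf posteriors into scaled pseudo-log-likelihoods:
//   out(t, j) = prob_scale * (log(max(post(t, j), 1e-20)) - log(prior(j)))
Mat PosteriorsToScaledLogLikes(const Mat &posteriors, const Row &priors,
                               float prob_scale) {
  Mat out = posteriors;
  for (Row &row : out) {
    // Mirrors the check that the priors match the number of pdfs.
    assert(row.size() == priors.size());
    for (std::size_t j = 0; j < row.size(); ++j) {
      float p = std::max(row[j], 1.0e-20f);  // floor to avoid log(0)
      row[j] = prob_scale * (std::log(p) - std::log(priors[j]));
    }
  }
  return out;
}
```

For example, under these assumptions NumOutputRows(10, false, 2, 3) gives 5 rows, while NumOutputRows(10, true, 2, 3) gives 10.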

Member Function Documentation

◆ IsLastFrame()

virtual bool IsLastFrame ( int32  frame) const
inline virtual

Returns true if this is the last frame.

Frames are zero-based, so the first frame is frame zero. IsLastFrame(-1) returns false unless the file is empty (a case that I'm not sure all the code handles, so be careful).

Caution: the behavior of this function in an online setting is being changed somewhat. In future it may return false while we have not yet decided to terminate decoding, and true later once we decide to terminate. The plan is to rely more on NumFramesReady(): IsLastFrame() would then always return false in an online-decoding setting, and would only return true in a decoding-from-matrix setting where we want to allow the last delta or LDA features to be flushed out for compatibility with the baseline setup.

Implements DecodableInterface.

Definition at line 87 of file decodable-am-nnet.h.

References KALDI_ASSERT, and DecodableAmNnet::NumFramesReady().

virtual bool IsLastFrame(int32 frame) const {
  KALDI_ASSERT(frame < NumFramesReady());
  return (frame == NumFramesReady() - 1);
}
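The edge cases described above (frame -1, and the empty input) follow directly from the two-line implementation. The following is a minimal sketch, assuming a hypothetical FrameCounter struct that only stores a frame count and uses the standard assert in place of KALDI_ASSERT; it is not a Kaldi class.

```cpp
#include <cassert>

// Hypothetical stand-in that only tracks how many frames are available.
struct FrameCounter {
  int num_frames;

  int NumFramesReady() const { return num_frames; }

  bool IsLastFrame(int frame) const {
    // Asking about a frame that is not yet available is a programming error.
    assert(frame < NumFramesReady());
    return frame == NumFramesReady() - 1;
  }
};
```

With three frames, IsLastFrame(2) is true and IsLastFrame(-1) is false. With zero frames, IsLastFrame(-1) is true, because -1 equals NumFramesReady() - 1; this is the empty-file case that the description warns about.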

◆ KALDI_DISALLOW_COPY_AND_ASSIGN()

KALDI_DISALLOW_COPY_AND_ASSIGN ( DecodableAmNnet  )
protected

◆ LogLikelihood()

virtual BaseFloat LogLikelihood ( int32  frame,
int32  transition_id 
)
inline virtual

Returns the log likelihood, which will be negated in the decoder.

The "frame" starts from zero. You should verify that NumFramesReady() > frame before calling this.

Implements DecodableInterface.

Definition at line 77 of file decodable-am-nnet.h.

References DecodableAmNnet::log_probs_, DecodableAmNnet::trans_model_, and TransitionModel::TransitionIdToPdfFast().

Referenced by kaldi::nnet2::UnitTestNnetDecodable().

virtual BaseFloat LogLikelihood(int32 frame, int32 transition_id) {
  return log_probs_(frame,
                    trans_model_.TransitionIdToPdfFast(transition_id));
}
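The lookup is a two-step indexing operation: the one-based transition-id is mapped to a pdf-id, and the pdf-id selects a column of the precomputed matrix. The following is a minimal sketch of that idea, assuming a hypothetical ToyDecodable struct with a plain vector as the transition-id to pdf-id table; the bounds checks are additions made for the sketch and are not taken from the Kaldi code.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for the lookup done by LogLikelihood().
struct ToyDecodable {
  // Maps transition-id to pdf-id. Entry 0 is unused because
  // transition-ids are one-based.
  std::vector<int> id2pdf;
  // Precomputed scores, indexed as log_probs[frame][pdf_id].
  std::vector<std::vector<float>> log_probs;

  float LogLikelihood(int frame, int transition_id) const {
    assert(frame >= 0 && frame < static_cast<int>(log_probs.size()));
    assert(transition_id >= 1 &&
           transition_id < static_cast<int>(id2pdf.size()));
    return log_probs[frame][id2pdf[transition_id]];
  }
};
```

Several transition-ids may map to the same pdf-id, in which case they share a score on every frame.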

◆ NumFramesReady()

virtual int32 NumFramesReady ( ) const
inline virtual

The call NumFramesReady() will return the number of frames currently available for this decodable object.

This is for use in setups where you don't want the decoder to block while waiting for input. This is newly added as of Jan 2014, and I hope, going forward, to rely on this mechanism more than IsLastFrame to know when to stop decoding.

Reimplemented from DecodableInterface.

Definition at line 82 of file decodable-am-nnet.h.

References DecodableAmNnet::log_probs_, and MatrixBase< Real >::NumRows().

Referenced by DecodableAmNnet::IsLastFrame(), DecodableAmNnetParallel::IsLastFrame(), and kaldi::nnet2::UnitTestNnetDecodable().

virtual int32 NumFramesReady() const { return log_probs_.NumRows(); }
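In this class all frames are computed in the constructor, so NumFramesReady() is simply the row count of log_probs_. The non-blocking pattern the description refers to can be sketched as follows, assuming a hypothetical GrowingSource whose frame count increases over time and a hypothetical FrameConsumer that remembers how far it has decoded; neither is a Kaldi class.

```cpp
#include <cassert>

// Hypothetical source whose frame count can grow, as in online decoding.
struct GrowingSource {
  int frames_ready = 0;
  int NumFramesReady() const { return frames_ready; }
};

// Hypothetical consumer that remembers how many frames it has processed.
struct FrameConsumer {
  int num_frames_decoded = 0;

  // Processes every frame that is ready and not yet consumed, then returns
  // the number of frames processed by this call. It never waits for input.
  int AdvanceDecoding(const GrowingSource &src) {
    int target = src.NumFramesReady();
    int processed = 0;
    while (num_frames_decoded < target) {
      // A real decoder would query LogLikelihood(num_frames_decoded, ...) here.
      ++num_frames_decoded;
      ++processed;
    }
    return processed;
  }
};
```

Calling AdvanceDecoding() again before new frames arrive processes nothing and returns 0, so the caller decides when to stop rather than relying on IsLastFrame().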

◆ NumIndices()

virtual int32 NumIndices ( ) const
inline virtual

Returns the number of states in the acoustic model (they will be indexed one-based, i.e. from 1 to NumIndices(); this is for compatibility with OpenFst). In this class the value returned is the number of transition-ids in trans_model_.

Implements DecodableInterface.

Definition at line 85 of file decodable-am-nnet.h.

References TransitionModel::NumTransitionIds(), and DecodableAmNnet::trans_model_.

virtual int32 NumIndices() const { return trans_model_.NumTransitionIds(); }
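Because the indices are one-based, a loop over them runs from 1 to NumIndices() inclusive rather than from 0. The following is a minimal sketch, assuming a hypothetical BestIndex helper that receives the scores for one frame in a vector whose entry 0 is an unused placeholder; it is not part of Kaldi.

```cpp
#include <cassert>
#include <vector>

// Returns the one-based index with the highest score for a single frame.
// scores_one_based[0] is an unused placeholder, so the valid indices are
// 1 .. scores_one_based.size() - 1. Ties go to the lowest index.
int BestIndex(const std::vector<float> &scores_one_based) {
  int num_indices = static_cast<int>(scores_one_based.size()) - 1;
  assert(num_indices >= 1);
  int best = 1;
  for (int i = 2; i <= num_indices; ++i) {
    if (scores_one_based[i] > scores_one_based[best]) best = i;
  }
  return best;
}
```

Starting the loop at 0 would read the placeholder entry and could return an index that is not a valid transition-id.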

Member Data Documentation

◆ log_probs_

◆ trans_model_


The documentation for this class was generated from the following file: