DecodableNnetSimpleLooped Class Reference

#include <decodable-simple-looped.h>

Collaboration diagram for DecodableNnetSimpleLooped:

Public Member Functions

 DecodableNnetSimpleLooped (const DecodableNnetSimpleLoopedInfo &info, const MatrixBase< BaseFloat > &feats, const VectorBase< BaseFloat > *ivector=NULL, const MatrixBase< BaseFloat > *online_ivectors=NULL, int32 online_ivector_period=1)
 This constructor takes features as input, and you can either supply a single iVector input, estimated in batch-mode ('ivector'), or 'online' iVectors ('online_ivectors' and 'online_ivector_period'), or none at all. More...
 
int32 NumFrames () const
 
int32 OutputDim () const
 
void GetOutputForFrame (int32 subsampled_frame, VectorBase< BaseFloat > *output)
 
BaseFloat GetOutput (int32 subsampled_frame, int32 pdf_id)
 

Private Member Functions

 KALDI_DISALLOW_COPY_AND_ASSIGN (DecodableNnetSimpleLooped)
 
void AdvanceChunk ()
 
void AdvanceChunkInternal (const MatrixBase< BaseFloat > &input_feats, const VectorBase< BaseFloat > &ivector)
 
void GetCurrentIvector (int32 input_frame, Vector< BaseFloat > *ivector)
 
int32 GetIvectorDim () const
 

Private Attributes

const DecodableNnetSimpleLoopedInfo & info_
 
NnetComputer computer_
 
const MatrixBase< BaseFloat > & feats_
 
int32 num_subsampled_frames_
 
const VectorBase< BaseFloat > * ivector_
 
const MatrixBase< BaseFloat > * online_ivector_feats_
 
int32 online_ivector_period_
 
Matrix< BaseFloat > current_log_post_
 
int32 num_chunks_computed_
 
int32 current_log_post_subsampled_offset_
 

Detailed Description

Definition at line 161 of file decodable-simple-looped.h.

Constructor & Destructor Documentation

◆ DecodableNnetSimpleLooped()

DecodableNnetSimpleLooped ( const DecodableNnetSimpleLoopedInfo &  info,
const MatrixBase< BaseFloat > &  feats,
const VectorBase< BaseFloat > *  ivector = NULL,
const MatrixBase< BaseFloat > *  online_ivectors = NULL,
int32  online_ivector_period = 1 
)

This constructor takes features as input, and you can either supply a single iVector input, estimated in batch-mode ('ivector'), or 'online' iVectors ('online_ivectors' and 'online_ivector_period'), or none at all.

Note: it stores references to all arguments to the constructor, so don't delete them till this goes out of scope.

Parameters
[in] info: This helper class contains all the static pre-computed information this class needs, and contains a pointer to the neural net.
[in] feats: The input feature matrix.
[in] ivector: If you are using iVectors estimated in batch mode, a pointer to the iVector, else NULL.
[in] online_ivectors: If you are using iVectors estimated 'online', a pointer to the iVectors, else NULL.
[in] online_ivector_period: If you are using iVectors estimated 'online' (i.e. if online_ivectors != NULL), gives the periodicity (in frames) with which the iVectors are estimated.

Definition at line 93 of file decodable-simple-looped.cc.

References DecodableNnetSimpleLooped::feats_, NnetSimpleLoopedComputationOptions::frame_subsampling_factor, DecodableNnetSimpleLooped::info_, KALDI_ASSERT, DecodableNnetSimpleLooped::num_subsampled_frames_, and DecodableNnetSimpleLoopedInfo::opts.

98  :
99  info_(info),
100  computer_(info_.opts.compute_config, info_.computation,
101  info_.nnet, NULL), // NULL is 'nnet_to_update'
102  feats_(feats),
103  ivector_(ivector), online_ivector_feats_(online_ivectors),
104  online_ivector_period_(online_ivector_period),
105  num_chunks_computed_(0),
106  current_log_post_subsampled_offset_(-1) {
107  num_subsampled_frames_ =
108  (feats_.NumRows() + info_.opts.frame_subsampling_factor - 1) /
109  info_.opts.frame_subsampling_factor;
110  KALDI_ASSERT(!(ivector != NULL && online_ivectors != NULL));
111  KALDI_ASSERT(!(online_ivectors != NULL && online_ivector_period <= 0 &&
112  "You need to set the --online-ivector-period option!"));
113 }
const VectorBase< BaseFloat > * ivector_
const MatrixBase< BaseFloat > * online_ivector_feats_
const NnetSimpleLoopedComputationOptions & opts
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
const DecodableNnetSimpleLoopedInfo & info_

Member Function Documentation

◆ AdvanceChunk()

void AdvanceChunk ( )
private

Definition at line 137 of file decodable-simple-looped.cc.

References NnetComputer::AcceptInput(), NnetSimpleLoopedComputationOptions::acoustic_scale, CuMatrixBase< Real >::AddVecToRows(), DecodableNnetSimpleLooped::computer_, MatrixBase< Real >::CopyFromMat(), MatrixBase< Real >::CopyRowsFromVec(), DecodableNnetSimpleLooped::current_log_post_, DecodableNnetSimpleLooped::current_log_post_subsampled_offset_, VectorBase< Real >::Dim(), DecodableNnetSimpleLooped::feats_, NnetSimpleLoopedComputationOptions::frame_subsampling_factor, DecodableNnetSimpleLoopedInfo::frames_left_context, DecodableNnetSimpleLoopedInfo::frames_per_chunk, DecodableNnetSimpleLoopedInfo::frames_right_context, DecodableNnetSimpleLooped::GetCurrentIvector(), NnetComputer::GetOutputDestructive(), DecodableNnetSimpleLoopedInfo::has_ivectors, DecodableNnetSimpleLooped::info_, ComputationRequest::inputs, KALDI_ASSERT, kaldi::kUndefined, DecodableNnetSimpleLoopedInfo::log_priors, DecodableNnetSimpleLooped::num_chunks_computed_, MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), DecodableNnetSimpleLoopedInfo::opts, DecodableNnetSimpleLoopedInfo::output_dim, DecodableNnetSimpleLoopedInfo::request1, DecodableNnetSimpleLoopedInfo::request2, Matrix< Real >::Resize(), CuMatrixBase< Real >::Row(), NnetComputer::Run(), CuMatrixBase< Real >::Scale(), and Matrix< Real >::Swap().

Referenced by DecodableNnetSimpleLooped::GetOutputForFrame().

137  {
138  int32 begin_input_frame, end_input_frame;
139  if (num_chunks_computed_ == 0) {
140  begin_input_frame = -info_.frames_left_context;
141  // note: end is last plus one.
142  end_input_frame = info_.frames_per_chunk + info_.frames_right_context;
143  } else {
144  begin_input_frame = num_chunks_computed_ * info_.frames_per_chunk +
145  info_.frames_right_context;
146  end_input_frame = begin_input_frame + info_.frames_per_chunk;
147  }
148  CuMatrix<BaseFloat> feats_chunk(end_input_frame - begin_input_frame,
149  feats_.NumCols(), kUndefined);
150 
151  int32 num_features = feats_.NumRows();
152  if (begin_input_frame >= 0 && end_input_frame <= num_features) {
153  SubMatrix<BaseFloat> this_feats(feats_,
154  begin_input_frame,
155  end_input_frame - begin_input_frame,
156  0, feats_.NumCols());
157  feats_chunk.CopyFromMat(this_feats);
158  } else {
159  Matrix<BaseFloat> this_feats(end_input_frame - begin_input_frame,
160  feats_.NumCols());
161  for (int32 r = begin_input_frame; r < end_input_frame; r++) {
162  int32 input_frame = r;
163  if (input_frame < 0) input_frame = 0;
164  if (input_frame >= num_features) input_frame = num_features - 1;
165  this_feats.Row(r - begin_input_frame).CopyFromVec(
166  feats_.Row(input_frame));
167  }
168  feats_chunk.CopyFromMat(this_feats);
169  }
170  computer_.AcceptInput("input", &feats_chunk);
171 
172  if (info_.has_ivectors) {
173  KALDI_ASSERT(info_.request1.inputs.size() == 2);
174  // all but the 1st chunk should have 1 iVector, but no need
175  // to assume this.
176  int32 num_ivectors = (num_chunks_computed_ == 0 ?
177  info_.request1.inputs[1].indexes.size() :
178  info_.request2.inputs[1].indexes.size());
179  KALDI_ASSERT(num_ivectors > 0);
180 
181  Vector<BaseFloat> ivector;
182  // we just get the iVector from the last input frame we needed...
183  // we don't bother trying to be 'accurate' in getting the iVectors
184  // for their 'correct' frames, because in general using the
185  // iVector from as large 't' as possible will be better.
186  GetCurrentIvector(end_input_frame, &ivector);
187  Matrix<BaseFloat> ivectors(num_ivectors,
188  ivector.Dim());
189  ivectors.CopyRowsFromVec(ivector);
190  CuMatrix<BaseFloat> cu_ivectors(ivectors);
191  computer_.AcceptInput("ivector", &cu_ivectors);
192  }
193  computer_.Run();
194 
195  {
196  // Note: it's possible in theory that if you had weird recurrence that went
197  // directly from the output, the call to GetOutputDestructive() would cause
198  // a crash on the next chunk. If that happens, GetOutput() should be used
199  // instead of GetOutputDestructive(). But we don't anticipate this will
200  // happen in practice.
201  CuMatrix<BaseFloat> output;
202  computer_.GetOutputDestructive("output", &output);
203 
204  if (info_.log_priors.Dim() != 0) {
205  // subtract log-prior (divide by prior)
206  output.AddVecToRows(-1.0, info_.log_priors);
207  }
208  // apply the acoustic scale
209  output.Scale(info_.opts.acoustic_scale);
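210  current_log_post_.Resize(0, 0);
211  current_log_post_.Swap(&output);
212  }
213  KALDI_ASSERT(current_log_post_.NumRows() == info_.frames_per_chunk /
214  info_.opts.frame_subsampling_factor &&
215  current_log_post_.NumCols() == info_.output_dim);
216 
217  num_chunks_computed_++;
218 
219  current_log_post_subsampled_offset_ =
220  (num_chunks_computed_ - 1) *
221  (info_.frames_per_chunk / info_.opts.frame_subsampling_factor);
222 }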
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
void GetCurrentIvector(int32 input_frame, Vector< BaseFloat > *ivector)
kaldi::int32 int32
std::vector< IoSpecification > inputs
void Swap(Matrix< Real > *other)
Swaps the contents of *this and *other. Shallow swap.
const NnetSimpleLoopedComputationOptions & opts
void AcceptInput(const std::string &node_name, CuMatrix< BaseFloat > *input)
e.g.
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
void Resize(const MatrixIndexT r, const MatrixIndexT c, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Sets matrix to a specified size (zero is OK as long as both r and c are zero).
void GetOutputDestructive(const std::string &output_name, CuMatrix< BaseFloat > *output)
const DecodableNnetSimpleLoopedInfo & info_
void Run()
This does either the forward or backward computation, depending when it is called (in a typical compu...

◆ AdvanceChunkInternal()

void AdvanceChunkInternal ( const MatrixBase< BaseFloat > &  input_feats,
const VectorBase< BaseFloat > &  ivector 
)
private

◆ GetCurrentIvector()

void GetCurrentIvector ( int32  input_frame,
Vector< BaseFloat > *  ivector 
)
private

Definition at line 225 of file decodable-simple-looped.cc.

References DecodableNnetSimpleLoopedInfo::has_ivectors, DecodableNnetSimpleLooped::info_, DecodableNnetSimpleLooped::ivector_, KALDI_ASSERT, KALDI_ERR, DecodableNnetSimpleLooped::online_ivector_feats_, and DecodableNnetSimpleLooped::online_ivector_period_.

Referenced by DecodableNnetSimpleLooped::AdvanceChunk().

226  {
227  if (!info_.has_ivectors)
228  return;
229  if (ivector_ != NULL) {
230  *ivector = *ivector_;
231  return;
232  } else if (online_ivector_feats_ == NULL) {
233  KALDI_ERR << "Neural net expects iVectors but none provided.";
234  }
235  KALDI_ASSERT(online_ivector_period_ > 0);
236  int32 ivector_frame = input_frame / online_ivector_period_;
237  KALDI_ASSERT(ivector_frame >= 0);
238  if (ivector_frame >= online_ivector_feats_->NumRows())
239  ivector_frame = online_ivector_feats_->NumRows() - 1;
240  KALDI_ASSERT(ivector_frame >= 0 && "ivector matrix cannot be empty.");
241  *ivector = online_ivector_feats_->Row(ivector_frame);
242 }
const VectorBase< BaseFloat > * ivector_
const MatrixBase< BaseFloat > * online_ivector_feats_
kaldi::int32 int32
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
const DecodableNnetSimpleLoopedInfo & info_

◆ GetIvectorDim()

int32 GetIvectorDim ( ) const
private

Definition at line 127 of file decodable-simple-looped.cc.

References DecodableNnetSimpleLooped::ivector_, and DecodableNnetSimpleLooped::online_ivector_feats_.

127  {
128  if (ivector_ != NULL)
129  return ivector_->Dim();
130  else if (online_ivector_feats_ != NULL)
131  return online_ivector_feats_->NumCols();
132  else
133  return 0;
134 }
const VectorBase< BaseFloat > * ivector_
const MatrixBase< BaseFloat > * online_ivector_feats_

◆ GetOutput()

BaseFloat GetOutput ( int32  subsampled_frame,
int32  pdf_id 
)
inline

Definition at line 208 of file decodable-simple-looped.h.

References KALDI_ASSERT, and KALDI_DISALLOW_COPY_AND_ASSIGN.

Referenced by DecodableAmNnetSimpleLooped::LogLikelihood().

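208  {
209  KALDI_ASSERT(subsampled_frame >= current_log_post_subsampled_offset_ &&
210  "Frames must be accessed in order.");
211  while (subsampled_frame >= current_log_post_subsampled_offset_ +
212  current_log_post_.NumRows())
213  AdvanceChunk();
214  return current_log_post_(subsampled_frame -
215  current_log_post_subsampled_offset_,
216  pdf_id);
217  }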
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64

◆ GetOutputForFrame()

void GetOutputForFrame ( int32  subsampled_frame,
VectorBase< BaseFloat > *  output 
)

Definition at line 116 of file decodable-simple-looped.cc.

References DecodableNnetSimpleLooped::AdvanceChunk(), VectorBase< Real >::CopyFromVec(), DecodableNnetSimpleLooped::current_log_post_, DecodableNnetSimpleLooped::current_log_post_subsampled_offset_, KALDI_ASSERT, MatrixBase< Real >::NumRows(), and MatrixBase< Real >::Row().

Referenced by kaldi::nnet3::TestNnetDecodable().

117  {
118  KALDI_ASSERT(subsampled_frame >= current_log_post_subsampled_offset_ &&
119  "Frames must be accessed in order.");
120  while (subsampled_frame >= current_log_post_subsampled_offset_ +
121  current_log_post_.NumRows())
122  AdvanceChunk();
123  output->CopyFromVec(current_log_post_.Row(
124  subsampled_frame - current_log_post_subsampled_offset_));
125 }
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64

◆ KALDI_DISALLOW_COPY_AND_ASSIGN()

KALDI_DISALLOW_COPY_AND_ASSIGN ( DecodableNnetSimpleLooped  )
private

◆ NumFrames()

int32 NumFrames ( ) const
inline

◆ OutputDim()

int32 OutputDim ( ) const
inline

Member Data Documentation

◆ computer_

NnetComputer computer_
private

Definition at line 237 of file decodable-simple-looped.h.

Referenced by DecodableNnetSimpleLooped::AdvanceChunk().

◆ current_log_post_

Matrix<BaseFloat> current_log_post_
private

◆ current_log_post_subsampled_offset_

int32 current_log_post_subsampled_offset_
private

◆ feats_

const MatrixBase<BaseFloat>& feats_
private

◆ info_

const DecodableNnetSimpleLoopedInfo& info_
private

◆ ivector_

const VectorBase<BaseFloat>* ivector_
private

◆ num_chunks_computed_

int32 num_chunks_computed_
private

Definition at line 259 of file decodable-simple-looped.h.

Referenced by DecodableNnetSimpleLooped::AdvanceChunk().

◆ num_subsampled_frames_

int32 num_subsampled_frames_
private

◆ online_ivector_feats_

const MatrixBase<BaseFloat>* online_ivector_feats_
private

◆ online_ivector_period_

int32 online_ivector_period_
private

The documentation for this class was generated from the following files: