SingleUtteranceNnet2DecoderThreaded Class Reference

You will instantiate this class when you want to decode a single utterance using the online-decoding setup for neural nets.

#include <online-nnet2-decoding-threaded.h>

Public Member Functions

 SingleUtteranceNnet2DecoderThreaded (const OnlineNnet2DecodingThreadedConfig &config, const TransitionModel &tmodel, const nnet2::AmNnet &am_nnet, const fst::Fst< fst::StdArc > &fst, const OnlineNnet2FeaturePipelineInfo &feature_info, const OnlineIvectorExtractorAdaptationState &adaptation_state, const OnlineCmvnState &cmvn_state)
 
void AcceptWaveform (BaseFloat samp_freq, const VectorBase< BaseFloat > &wave_part)
 You call this to provide this class with more waveform to decode.
 
int32 NumWaveformPiecesPending ()
 Returns the number of pieces of waveform that are still waiting to be processed.
 
void InputFinished ()
 You call this to inform the class that no more waveform will be provided; this allows it to flush out the last few frames of features, and is necessary if you want to call Wait() to wait until all decoding is done.
 
void TerminateDecoding ()
 You can call this if you don't want the decoding to proceed further with this utterance.
 
void Wait ()
 This call will block until all the data has been decoded; it must only be called after either InputFinished() or TerminateDecoding() has been called; calling it otherwise is an error.
 
void FinalizeDecoding ()
 Finalizes the decoding.
 
int32 NumFramesReceivedApprox () const
 Returns *approximately* (ignoring end effects) the number of frames of data that we expect given the amount of data that the pipeline has received via AcceptWaveform().
 
int32 NumFramesDecoded () const
 Returns the number of frames currently decoded.
 
void GetLattice (bool end_of_utterance, CompactLattice *clat, BaseFloat *final_relative_cost) const
 Gets the lattice.
 
void GetBestPath (bool end_of_utterance, Lattice *best_path, BaseFloat *final_relative_cost) const
 Outputs an FST corresponding to the single best path through the current lattice.
 
bool EndpointDetected (const OnlineEndpointConfig &config)
 This function calls EndpointDetected from online-endpoint.h, with the required arguments.
 
void GetAdaptationState (OnlineIvectorExtractorAdaptationState *adaptation_state)
 Outputs the adaptation state of the feature pipeline to "adaptation_state".
 
void GetCmvnState (OnlineCmvnState *cmvn_state)
 Outputs the OnlineCmvnState of the feature pipeline to "cmvn_state".
 
BaseFloat GetRemainingWaveform (Vector< BaseFloat > *waveform_out) const
 Gets the remaining, un-decoded part of the waveform and returns the sample rate.
 
 ~SingleUtteranceNnet2DecoderThreaded ()
 

Private Member Functions

void AbortAllThreads (bool error)
 
void WaitForAllThreads ()
 
bool RunNnetEvaluationInternal ()
 
void ProcessLoglikes (const CuVector< BaseFloat > &log_inv_prior, CuMatrixBase< BaseFloat > *loglikes)
 
bool FeatureComputation (int32 num_frames_output)
 
bool RunDecoderSearchInternal ()
 

Static Private Member Functions

static void RunNnetEvaluation (SingleUtteranceNnet2DecoderThreaded *me)
 
static void RunDecoderSearch (SingleUtteranceNnet2DecoderThreaded *me)
 

Private Attributes

OnlineNnet2DecodingThreadedConfig config_
 
const nnet2::AmNnet & am_nnet_
 
const TransitionModel & tmodel_
 
BaseFloat sampling_rate_
 
int64 num_samples_received_
 
bool input_finished_
 
std::deque< Vector< BaseFloat > *> input_waveform_
 
ThreadSynchronizer waveform_synchronizer_
 
OnlineNnet2FeaturePipeline feature_pipeline_
 
std::mutex feature_pipeline_mutex_
 
std::deque< Vector< BaseFloat > *> processed_waveform_
 
int64 num_samples_discarded_
 
OnlineSilenceWeighting silence_weighting_
 
std::mutex silence_weighting_mutex_
 
DecodableMatrixMappedOffset decodable_
 
int32 num_frames_decoded_
 
ThreadSynchronizer decodable_synchronizer_
 
LatticeFasterOnlineDecoder decoder_
 
std::mutex decoder_mutex_
 
std::thread threads_ [2]
 
bool abort_
 
bool error_
 

Detailed Description

You will instantiate this class when you want to decode a single utterance using the online-decoding setup for neural nets.

Each time this class is created, it creates two background threads (see threads_): one performs feature extraction and neural net evaluation, and the other performs the decoder search. Note: we assume that all calls to its public interface happen from a single thread.
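
As a rough illustration of the calling pattern described above (feed audio from one thread while a background thread does the work, then signal end of input and wait), here is a minimal sketch. `ToyThreadedDecoder` and `NumSamplesProcessed()` are hypothetical names invented for this example. The sketch uses a single worker and a plain condition variable rather than Kaldi's ThreadSynchronizer, and it only counts samples instead of decoding.

```cpp
#include <cassert>
#include <condition_variable>
#include <cstddef>
#include <deque>
#include <mutex>
#include <thread>
#include <utility>
#include <vector>

// Hypothetical stand-in, NOT the Kaldi class: it only counts samples, but it
// follows the same calling pattern (feed, finish, wait) from a single thread.
class ToyThreadedDecoder {
 public:
  ToyThreadedDecoder() : worker_(&ToyThreadedDecoder::Run, this) {}

  ~ToyThreadedDecoder() {
    if (worker_.joinable()) {  // caller never called Wait(): shut down cleanly
      InputFinished();
      worker_.join();
    }
  }

  // Queues a copy of the chunk and returns immediately.
  void AcceptWaveform(const std::vector<float> &chunk) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      assert(!input_finished_ && "AcceptWaveform called after InputFinished");
      pending_.push_back(chunk);
    }
    cv_.notify_one();
  }

  // Tells the worker that no more chunks will arrive.
  void InputFinished() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      input_finished_ = true;
    }
    cv_.notify_one();
  }

  // Blocks until the worker has drained the queue and exited.
  void Wait() {
    if (worker_.joinable()) worker_.join();
  }

  std::size_t NumSamplesProcessed() {
    std::lock_guard<std::mutex> lock(mutex_);
    return num_samples_processed_;
  }

 private:
  void Run() {
    std::unique_lock<std::mutex> lock(mutex_);
    while (true) {
      cv_.wait(lock, [this] { return !pending_.empty() || input_finished_; });
      if (pending_.empty()) return;  // input finished and nothing left
      std::vector<float> chunk = std::move(pending_.front());
      pending_.pop_front();
      lock.unlock();
      std::size_t n = chunk.size();  // the "decoding" happens outside the lock
      lock.lock();
      num_samples_processed_ += n;
    }
  }

  std::mutex mutex_;
  std::condition_variable cv_;
  std::deque<std::vector<float>> pending_;
  bool input_finished_ = false;
  std::size_t num_samples_processed_ = 0;
  std::thread worker_;  // declared last so the members above exist first
};
```

With this toy class the call order mirrors the one documented here: AcceptWaveform() as often as needed, then InputFinished(), then Wait(), and only then read the result.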

Definition at line 190 of file online-nnet2-decoding-threaded.h.

Constructor & Destructor Documentation

◆ SingleUtteranceNnet2DecoderThreaded()

SingleUtteranceNnet2DecoderThreaded ( const OnlineNnet2DecodingThreadedConfig &  config,
const TransitionModel &  tmodel,
const nnet2::AmNnet &  am_nnet,
const fst::Fst< fst::StdArc > &  fst,
const OnlineNnet2FeaturePipelineInfo &  feature_info,
const OnlineIvectorExtractorAdaptationState &  adaptation_state,
const OnlineCmvnState &  cmvn_state 
)

Definition at line 112 of file online-nnet2-decoding-threaded.cc.

References SingleUtteranceNnet2DecoderThreaded::decoder_, SingleUtteranceNnet2DecoderThreaded::feature_pipeline_, LatticeFasterDecoderTpl< FST, Token >::InitDecoding(), SingleUtteranceNnet2DecoderThreaded::RunDecoderSearch(), SingleUtteranceNnet2DecoderThreaded::RunNnetEvaluation(), OnlineNnet2FeaturePipeline::SetAdaptationState(), OnlineNnet2FeaturePipeline::SetCmvnState(), and SingleUtteranceNnet2DecoderThreaded::threads_.

119  :
120  config_(config), am_nnet_(am_nnet), tmodel_(tmodel), sampling_rate_(0.0),
121  num_samples_received_(0), input_finished_(false),
122  feature_pipeline_(feature_info),
123  num_samples_discarded_(0),
124  silence_weighting_(tmodel, feature_info.silence_weighting_config),
125  decodable_(tmodel),
126  num_frames_decoded_(0), decoder_(fst, config_.decoder_opts),
127  abort_(false), error_(false) {
128  // if the user supplies an adaptation state that was not freshly initialized,
129  // it means that we take the adaptation state from the previous
130  // utterance(s)... this only makes sense if those previous utterance(s) are
131  // believed to be from the same speaker.
132  feature_pipeline_.SetAdaptationState(adaptation_state);
133  feature_pipeline_.SetCmvnState(cmvn_state);
134  // spawn threads.
135  threads_[0] = std::thread(RunNnetEvaluation, this);
136  decoder_.InitDecoding();
137  threads_[1] = std::thread(RunDecoderSearch, this);
138 }

◆ ~SingleUtteranceNnet2DecoderThreaded()

Definition at line 141 of file online-nnet2-decoding-threaded.cc.

References SingleUtteranceNnet2DecoderThreaded::abort_, SingleUtteranceNnet2DecoderThreaded::AbortAllThreads(), SingleUtteranceNnet2DecoderThreaded::input_waveform_, SingleUtteranceNnet2DecoderThreaded::processed_waveform_, and SingleUtteranceNnet2DecoderThreaded::WaitForAllThreads().

141  {
142  if (!abort_) {
143  // If we have not already started the process of aborting the threads, do so now.
144  bool error = false;
145  AbortAllThreads(error);
146  }
147  // join all the threads (this avoids leaving zombie threads around, or threads
148  // that might be accessing deconstructed object).
149  WaitForAllThreads();
150  while (!input_waveform_.empty()) {
151  delete input_waveform_.front();
152  input_waveform_.pop_front();
153  }
154  while (!processed_waveform_.empty()) {
155  delete processed_waveform_.front();
156  processed_waveform_.pop_front();
157  }
158 }

Member Function Documentation

◆ AbortAllThreads()

◆ AcceptWaveform()

void AcceptWaveform ( BaseFloat  samp_freq,
const VectorBase< BaseFloat > &  wave_part 
)

You call this to provide this class with more waveform to decode.

This call is, for all practical purposes, non-blocking.

Definition at line 160 of file online-nnet2-decoding-threaded.cc.

References VectorBase< Real >::Dim(), SingleUtteranceNnet2DecoderThreaded::input_waveform_, KALDI_ASSERT, KALDI_ERR, ThreadSynchronizer::kProducer, ThreadSynchronizer::Lock(), SingleUtteranceNnet2DecoderThreaded::num_samples_received_, SingleUtteranceNnet2DecoderThreaded::sampling_rate_, ThreadSynchronizer::UnlockSuccess(), and SingleUtteranceNnet2DecoderThreaded::waveform_synchronizer_.

162  {
163  if (sampling_rate_ <= 0.0)
164  sampling_rate_ = sampling_rate;
165  else {
166  KALDI_ASSERT(sampling_rate == sampling_rate_);
167  }
168  num_samples_received_ += wave_part.Dim();
169 
170  if (wave_part.Dim() == 0) return;
171  if (!waveform_synchronizer_.Lock(ThreadSynchronizer::kProducer)) {
172  KALDI_ERR << "Failure locking mutex: decoding aborted.";
173  }
174 
175  Vector<BaseFloat> *new_part = new Vector<BaseFloat>(wave_part);
176  input_waveform_.push_back(new_part);
177  // we always unlock with success because there is no buffer size limitation
178  // for the waveform so no reason why we might wait.
179  waveform_synchronizer_.UnlockSuccess(ThreadSynchronizer::kProducer);
180 }

◆ EndpointDetected()

bool EndpointDetected ( const OnlineEndpointConfig &  config)

This function calls EndpointDetected from online-endpoint.h, with the required arguments.

Definition at line 651 of file online-nnet2-decoding-threaded.cc.

References SingleUtteranceNnet2DecoderThreaded::decoder_, SingleUtteranceNnet2DecoderThreaded::decoder_mutex_, kaldi::EndpointDetected(), SingleUtteranceNnet2DecoderThreaded::feature_pipeline_, OnlineNnet2FeaturePipeline::FrameShiftInSeconds(), and SingleUtteranceNnet2DecoderThreaded::tmodel_.

652  {
653  std::lock_guard<std::mutex> lock(decoder_mutex_);
654  return kaldi::EndpointDetected(config, tmodel_,
655  feature_pipeline_.FrameShiftInSeconds(),
656  decoder_);
657 }

◆ FeatureComputation()

bool FeatureComputation ( int32  num_frames_output)
private

Definition at line 412 of file online-nnet2-decoding-threaded.cc.

References OnlineNnet2FeaturePipeline::AcceptWaveform(), SingleUtteranceNnet2DecoderThreaded::config_, SingleUtteranceNnet2DecoderThreaded::feature_pipeline_, OnlineNnet2FeaturePipeline::FrameShiftInSeconds(), SingleUtteranceNnet2DecoderThreaded::input_finished_, SingleUtteranceNnet2DecoderThreaded::input_waveform_, OnlineNnet2FeaturePipeline::InputFinished(), OnlineNnet2FeaturePipeline::IsLastFrame(), KALDI_ASSERT, ThreadSynchronizer::kConsumer, ThreadSynchronizer::Lock(), OnlineNnet2DecodingThreadedConfig::nnet_batch_size, SingleUtteranceNnet2DecoderThreaded::num_frames_decoded_, SingleUtteranceNnet2DecoderThreaded::num_samples_discarded_, OnlineNnet2FeaturePipeline::NumFramesReady(), SingleUtteranceNnet2DecoderThreaded::processed_waveform_, SingleUtteranceNnet2DecoderThreaded::sampling_rate_, ThreadSynchronizer::UnlockFailure(), ThreadSynchronizer::UnlockSuccess(), and SingleUtteranceNnet2DecoderThreaded::waveform_synchronizer_.

Referenced by SingleUtteranceNnet2DecoderThreaded::RunNnetEvaluationInternal().

413  {
414 
415  int32 num_frames_ready = feature_pipeline_.NumFramesReady(),
416  num_frames_usable = num_frames_ready - num_frames_consumed;
417  bool features_done = feature_pipeline_.IsLastFrame(num_frames_ready - 1);
418  KALDI_ASSERT(num_frames_usable >= 0);
419  if (features_done) {
420  return true; // nothing to do. (but not an error).
421  } else {
422  if (num_frames_usable >= config_.nnet_batch_size)
423  return true; // We don't need more data yet.
424 
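425  // Now try to get more data, if we can.
426  if (!waveform_synchronizer_.Lock(ThreadSynchronizer::kConsumer)) {
427  return false;
428  }
429  // we've got the lock.
430  if (input_waveform_.empty()) { // we got no data
431  if (input_finished_ &&
432  !feature_pipeline_.IsLastFrame(feature_pipeline_.NumFramesReady()-1)) {
433  // the main thread called InputFinished() and set input_finished_, and
434  // we haven't yet registered that fact. This is progress so
435  // unlock with UnlockSuccess().
436  feature_pipeline_.InputFinished();
437  return waveform_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer);
438  } else {
439  // there is no progress. Unlock with UnlockFailure() so the next call to
440  // waveform_synchronizer_.Lock() will lock.
441  return waveform_synchronizer_.UnlockFailure(ThreadSynchronizer::kConsumer);
442  }
443  } else { // we got some data. Only take enough of the waveform to
444  // give us a maximum nnet batch size of frames to decode.
445  while (num_frames_usable < config_.nnet_batch_size &&
446  !input_waveform_.empty()) {
447  feature_pipeline_.AcceptWaveform(sampling_rate_, *input_waveform_.front());
448  processed_waveform_.push_back(input_waveform_.front());
449  input_waveform_.pop_front();
450  num_frames_ready = feature_pipeline_.NumFramesReady();
451  num_frames_usable = num_frames_ready - num_frames_consumed;
452  }
453  // Delete already-processed pieces of waveform if we have already decoded
454  // those frames. (If not already decoded, we keep them around for the
455  // sake of GetRemainingWaveform()).
456  int32 samples_shift_per_frame =
457  sampling_rate_ * feature_pipeline_.FrameShiftInSeconds();
458  while (!processed_waveform_.empty() &&
459  num_samples_discarded_ + processed_waveform_.front()->Dim() <
460  samples_shift_per_frame * num_frames_decoded_) {
461  num_samples_discarded_ += processed_waveform_.front()->Dim();
462  delete processed_waveform_.front();
463  processed_waveform_.pop_front();
464  }
465  return waveform_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer);
466  }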
467  }
468 }

◆ FinalizeDecoding()

void FinalizeDecoding ( )

Finalizes the decoding.

Cleans up and prunes remaining tokens, so the final lattice is faster to obtain. May not be called unless either InputFinished() or TerminateDecoding() has been called. If InputFinished() was called, it calls Wait() to ensure that the decoding has finished (it's not an error if you already called Wait()).

Definition at line 228 of file online-nnet2-decoding-threaded.cc.

References SingleUtteranceNnet2DecoderThreaded::decoder_, LatticeFasterDecoderTpl< FST, Token >::FinalizeDecoding(), KALDI_ERR, and SingleUtteranceNnet2DecoderThreaded::threads_.

228  {
229  if (threads_[0].joinable()) {
230  KALDI_ERR << "It is an error to call FinalizeDecoding before Wait().";
231  }
232  decoder_.FinalizeDecoding();
233 }
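
The guard in this listing relies on a property of std::thread: a thread object stays joinable until join() is called on it, even if its function has already returned. Below is a small sketch of that behavior. It assumes (this is not shown in the listing) that Wait() is what joins the background threads; `FinalizeAllowed` and `JoinableDemo` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <thread>

// Hypothetical helper mirroring the guard: finalizing is only allowed once
// the worker thread is no longer joinable, i.e. after it has been joined.
bool FinalizeAllowed(const std::thread &worker) {
  return !worker.joinable();
}

void JoinableDemo() {
  std::thread worker([] {});         // the thread function returns at once
  assert(!FinalizeAllowed(worker));  // still joinable until join() is called
  worker.join();                     // what a Wait()-style call would do
  assert(FinalizeAllowed(worker));   // now the guard would pass
}
```

The point of the sketch is that the check does not ask whether the thread is still running, only whether it has been joined, which is why calling Wait() first is required.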

◆ GetAdaptationState()

void GetAdaptationState ( OnlineIvectorExtractorAdaptationState *  adaptation_state)

Outputs the adaptation state of the feature pipeline to "adaptation_state".

This mostly stores stats for iVector estimation, and will generally be called at the end of an utterance, assuming it's a scenario where each speaker is seen for more than one utterance. You may only call this function after calling either TerminateDecoding() or InputFinished(), followed by Wait(); otherwise it is an error.

Definition at line 283 of file online-nnet2-decoding-threaded.cc.

References SingleUtteranceNnet2DecoderThreaded::feature_pipeline_, SingleUtteranceNnet2DecoderThreaded::feature_pipeline_mutex_, and OnlineNnet2FeaturePipeline::GetAdaptationState().

284  {
285  std::lock_guard<std::mutex> lock(feature_pipeline_mutex_);
286  // If this blocks, it shouldn't be for very long.
287  feature_pipeline_.GetAdaptationState(adaptation_state);
288 }

◆ GetBestPath()

void GetBestPath ( bool  end_of_utterance,
Lattice *  best_path,
BaseFloat *  final_relative_cost 
) const

Outputs an FST corresponding to the single best path through the current lattice.

If "end_of_utterance" is true AND we reached the final-state of the graph, then it will include those as final-probs; otherwise it will treat all final-probs as one. If no frames have been decoded yet, it will set best_path to a lattice with a single state that is final and with unit weight (no cost). The output to final_relative_cost (if non-NULL) is a number >= 0 that is closer to 0 if a final-state was close to the best-likelihood state active on the last frame, at the time we got the best path.

Definition at line 323 of file online-nnet2-decoding-threaded.cc.

References SingleUtteranceNnet2DecoderThreaded::decoder_, SingleUtteranceNnet2DecoderThreaded::decoder_mutex_, LatticeFasterDecoderTpl< FST, Token >::FinalRelativeCost(), LatticeFasterOnlineDecoderTpl< FST >::GetBestPath(), LatticeFasterDecoderTpl< FST, Token >::NumFramesDecoded(), and LatticeWeightTpl< BaseFloat >::One().

326  {
327  std::lock_guard<std::mutex> lock(decoder_mutex_);
328  if (decoder_.NumFramesDecoded() == 0) {
329  // It's possible that this if-statement is not necessary because we'd get this
330  // anyway if we just called GetBestPath on the decoder.
331  best_path->DeleteStates();
332  best_path->SetFinal(best_path->AddState(),
333  LatticeWeight::One());
334  if (final_relative_cost != NULL)
335  *final_relative_cost = std::numeric_limits<BaseFloat>::infinity();
336  } else {
337  decoder_.GetBestPath(best_path,
338  end_of_utterance);
339  if (final_relative_cost != NULL)
340  *final_relative_cost = decoder_.FinalRelativeCost();
341  }
342 }

◆ GetCmvnState()

void GetCmvnState ( OnlineCmvnState *  cmvn_state)

Outputs the OnlineCmvnState of the feature pipeline to "cmvn_state".

This stores CMVN stats for the non-iVector features, and will be called at the end of an utterance, assuming it's a scenario where each speaker is seen for more than one utterance. You may only call this function after calling either TerminateDecoding() or InputFinished(), followed by Wait(); otherwise it is an error.

Definition at line 290 of file online-nnet2-decoding-threaded.cc.

References SingleUtteranceNnet2DecoderThreaded::feature_pipeline_, SingleUtteranceNnet2DecoderThreaded::feature_pipeline_mutex_, and OnlineNnet2FeaturePipeline::GetCmvnState().

291  {
292  std::lock_guard<std::mutex> lock(feature_pipeline_mutex_);
293  // If this blocks, it shouldn't be for very long.
294  feature_pipeline_.GetCmvnState(cmvn_state);
295 }

◆ GetLattice()

void GetLattice ( bool  end_of_utterance,
CompactLattice *  clat,
BaseFloat *  final_relative_cost 
) const

Gets the lattice.

The output lattice has any acoustic scaling in it (which will typically be desirable in an online-decoding context); if you want an un-scaled lattice, scale it using ScaleLattice() with the inverse of the acoustic weight. Set "end_of_utterance" to true if you want the final-probs to be included. If this is at the end of the utterance, you might want to call FinalizeDecoding() first; this will make this call return faster. If no frames have been decoded yet, it will set clat to a lattice with a single state that is final and with unit weight (no cost or alignment). The output to final_relative_cost (if non-NULL) is a number >= 0 that is closer to 0 if a final-state was close to the best-likelihood state active on the last frame, at the time we obtained the lattice.

Definition at line 297 of file online-nnet2-decoding-threaded.cc.

References SingleUtteranceNnet2DecoderThreaded::config_, SingleUtteranceNnet2DecoderThreaded::decoder_, SingleUtteranceNnet2DecoderThreaded::decoder_mutex_, OnlineNnet2DecodingThreadedConfig::decoder_opts, LatticeFasterDecoderConfig::det_opts, LatticeFasterDecoderConfig::determinize_lattice, fst::DeterminizeLatticePhonePrunedWrapper(), LatticeFasterDecoderTpl< FST, Token >::FinalRelativeCost(), LatticeFasterDecoderTpl< FST, Token >::GetRawLattice(), KALDI_ERR, LatticeFasterDecoderConfig::lattice_beam, LatticeFasterDecoderTpl< FST, Token >::NumFramesDecoded(), CompactLatticeWeightTpl< WeightType, IntType >::One(), and SingleUtteranceNnet2DecoderThreaded::tmodel_.

300  {
301  clat->DeleteStates();
302  decoder_mutex_.lock();
303  if (final_relative_cost != NULL)
304  *final_relative_cost = decoder_.FinalRelativeCost();
305  if (decoder_.NumFramesDecoded() == 0) {
306  decoder_mutex_.unlock();
307  clat->SetFinal(clat->AddState(),
308  CompactLatticeWeight::One());
309  return;
310  }
311  Lattice raw_lat;
312  decoder_.GetRawLattice(&raw_lat, end_of_utterance);
313  decoder_mutex_.unlock();
314 
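315  if (!config_.decoder_opts.determinize_lattice)
316  KALDI_ERR << "--determinize-lattice=false option is not supported at the moment";
317 
318  BaseFloat lat_beam = config_.decoder_opts.lattice_beam;
319  DeterminizeLatticePhonePrunedWrapper(
320  tmodel_, &raw_lat, lat_beam, clat, config_.decoder_opts.det_opts);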
321 }

◆ GetRemainingWaveform()

BaseFloat GetRemainingWaveform ( Vector< BaseFloat > *  waveform_out) const

Gets the remaining, un-decoded part of the waveform and returns the sample rate.

May only be called after Wait(), and it only makes sense to call this if you called TerminateDecoding() before Wait(). The idea is that you can then provide this un-decoded piece of waveform to another decoder.

Definition at line 235 of file online-nnet2-decoding-threaded.cc.

References VectorBase< Real >::Dim(), SingleUtteranceNnet2DecoderThreaded::feature_pipeline_, OnlineNnet2FeaturePipeline::FrameShiftInSeconds(), rnnlm::i, SingleUtteranceNnet2DecoderThreaded::input_waveform_, KALDI_ASSERT, KALDI_ERR, kaldi::kUndefined, SingleUtteranceNnet2DecoderThreaded::num_frames_decoded_, SingleUtteranceNnet2DecoderThreaded::num_samples_discarded_, SingleUtteranceNnet2DecoderThreaded::processed_waveform_, VectorBase< Real >::Range(), Vector< Real >::Resize(), SingleUtteranceNnet2DecoderThreaded::sampling_rate_, and SingleUtteranceNnet2DecoderThreaded::threads_.

236  {
237  if (threads_[0].joinable()) {
238  KALDI_ERR << "It is an error to call GetRemainingWaveform before Wait().";
239  }
240  int64 num_samples_stored = 0; // number of samples we still have.
241  std::vector< Vector<BaseFloat>* > all_pieces;
242  std::deque< Vector<BaseFloat>* >::const_iterator iter;
243  for (iter = processed_waveform_.begin(); iter != processed_waveform_.end();
244  ++iter) {
245  num_samples_stored += (*iter)->Dim();
246  all_pieces.push_back(*iter);
247  }
248  for (iter = input_waveform_.begin(); iter != input_waveform_.end(); ++iter) {
249  num_samples_stored += (*iter)->Dim();
250  all_pieces.push_back(*iter);
251  }
252  int64 samples_shift_per_frame =
253  sampling_rate_ * feature_pipeline_.FrameShiftInSeconds();
254  int64 num_samples_to_discard = samples_shift_per_frame * num_frames_decoded_;
255  KALDI_ASSERT(num_samples_to_discard >= num_samples_discarded_);
256 
257  // num_samp_discard is how many samples we must discard from our stored
258  // samples.
259  int64 num_samp_discard = num_samples_to_discard - num_samples_discarded_,
260  num_samp_keep = num_samples_stored - num_samp_discard;
261  KALDI_ASSERT(num_samp_discard <= num_samples_stored && num_samp_keep >= 0);
262  waveform->Resize(num_samp_keep, kUndefined);
263  int32 offset = 0; // offset in output waveform. assume output waveform is no
264  // larger than int32.
265  for (size_t i = 0; i < all_pieces.size(); i++) {
266  Vector<BaseFloat> *this_piece = all_pieces[i];
267  int32 this_dim = this_piece->Dim();
268  if (num_samp_discard >= this_dim) {
269  num_samp_discard -= this_dim;
270  } else {
271  // normal case is num_samp_discard = 0.
272  int32 this_dim_keep = this_dim - num_samp_discard;
273  waveform->Range(offset, this_dim_keep).CopyFromVec(
274  this_piece->Range(num_samp_discard, this_dim_keep));
275  offset += this_dim_keep;
276  num_samp_discard = 0;
277  }
278  }
279  KALDI_ASSERT(offset == num_samp_keep && num_samp_discard == 0);
280  return sampling_rate_;
281 }
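
The final loop of this listing walks over the stored pieces, skips samples that belong to already-decoded frames, and copies the rest. The sketch below shows the same idea on plain std::vector data. `RemainingSamples` is a hypothetical helper, not a Kaldi function, and it takes the number of samples to discard as an input rather than deriving it from the frame count.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Concatenates the stored pieces, skipping the first num_discard samples.
// Simplified illustration only; the real function works on Kaldi vectors.
std::vector<float> RemainingSamples(
    const std::vector<std::vector<float>> &pieces, std::size_t num_discard) {
  std::size_t total = 0;
  for (const auto &piece : pieces) total += piece.size();
  assert(num_discard <= total);

  std::vector<float> out;
  out.reserve(total - num_discard);
  for (const auto &piece : pieces) {
    if (num_discard >= piece.size()) {
      num_discard -= piece.size();  // the whole piece was already decoded
    } else {
      // keep the tail of this piece; later pieces are kept in full
      out.insert(out.end(),
                 piece.begin() + static_cast<std::ptrdiff_t>(num_discard),
                 piece.end());
      num_discard = 0;
    }
  }
  return out;
}
```

For example, with pieces {1, 2, 3}, {4, 5}, {6} and four samples to discard, the first piece is skipped entirely, one sample of the second is dropped, and the result holds the two samples 5 and 6.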

◆ InputFinished()

void InputFinished ( )

You call this to inform the class that no more waveform will be provided; this allows it to flush out the last few frames of features, and is necessary if you want to call Wait() to wait until all decoding is done.

After calling InputFinished() you cannot call AcceptWaveform any more.

Definition at line 203 of file online-nnet2-decoding-threaded.cc.

References SingleUtteranceNnet2DecoderThreaded::input_finished_, KALDI_ASSERT, KALDI_ERR, ThreadSynchronizer::kProducer, ThreadSynchronizer::Lock(), ThreadSynchronizer::UnlockSuccess(), and SingleUtteranceNnet2DecoderThreaded::waveform_synchronizer_.

203  {
204  // setting input_finished_ = true informs the feature-processing pipeline
205  // to expect no more input, and to flush out the last few frames if there
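206  // is any latency in the pipeline (e.g. due to pitch).
207  if (!waveform_synchronizer_.Lock(ThreadSynchronizer::kProducer)) {
208  KALDI_ERR << "Failure locking mutex: decoding aborted.";
209  }
210  KALDI_ASSERT(!input_finished_ && "InputFinished called twice");
211  input_finished_ = true;
212  waveform_synchronizer_.UnlockSuccess(ThreadSynchronizer::kProducer);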
213 }

◆ NumFramesDecoded()

int32 NumFramesDecoded ( ) const

Returns the number of frames currently decoded.

Caution: don't rely on the lattice having exactly this number of frames if you get it after this call, as the number may increase in the meantime, unless you've already called either TerminateDecoding() or InputFinished(), followed by Wait().

Definition at line 352 of file online-nnet2-decoding-threaded.cc.

References SingleUtteranceNnet2DecoderThreaded::decoder_, SingleUtteranceNnet2DecoderThreaded::decoder_mutex_, and LatticeFasterDecoderTpl< FST, Token >::NumFramesDecoded().

352  {
353  std::lock_guard<std::mutex> lock(decoder_mutex_);
354  return decoder_.NumFramesDecoded();
355 }

◆ NumFramesReceivedApprox()

int32 NumFramesReceivedApprox ( ) const

Returns *approximately* (ignoring end effects), the number of frames of data that we expect given the amount of data that the pipeline has received via AcceptWaveform().

This might be useful in application code to compare with NumFramesDecoded() and gauge how much latency there is.

Definition at line 198 of file online-nnet2-decoding-threaded.cc.

References SingleUtteranceNnet2DecoderThreaded::feature_pipeline_, OnlineNnet2FeaturePipeline::FrameShiftInSeconds(), SingleUtteranceNnet2DecoderThreaded::num_samples_received_, and SingleUtteranceNnet2DecoderThreaded::sampling_rate_.
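
No listing is shown for this function, but the referenced members (num_samples_received_, sampling_rate_, FrameShiftInSeconds()) suggest a simple ratio. The sketch below is an assumption based on those names, not the confirmed implementation: it divides the samples received by the number of samples per frame shift. `NumFramesApprox` is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>

// Assumed formula (not confirmed by the listing): samples received divided
// by the number of samples per frame shift, truncated toward zero.
std::int32_t NumFramesApprox(std::int64_t num_samples_received,
                             double sampling_rate_hz,
                             double frame_shift_seconds) {
  assert(sampling_rate_hz > 0.0 && frame_shift_seconds > 0.0);
  double samples_per_shift = sampling_rate_hz * frame_shift_seconds;
  return static_cast<std::int32_t>(num_samples_received / samples_per_shift);
}
```

Under this assumption, 48080 samples at 16 kHz with a 10 ms frame shift correspond to about 300 frames. Subtracting NumFramesDecoded() from such an estimate would give a rough measure of how far decoding lags behind the input.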

◆ NumWaveformPiecesPending()

int32 NumWaveformPiecesPending ( )

Returns the number of pieces of waveform that are still waiting to be processed.

This may be useful for calling code to judge whether to supply more waveform or to wait.

Definition at line 182 of file online-nnet2-decoding-threaded.cc.

References SingleUtteranceNnet2DecoderThreaded::input_waveform_, KALDI_ERR, ThreadSynchronizer::kProducer, ThreadSynchronizer::Lock(), ThreadSynchronizer::UnlockSuccess(), and SingleUtteranceNnet2DecoderThreaded::waveform_synchronizer_.

182  {
183  // Note RE locking: what we really want here is just to lock the mutex. As a
184  // side effect, because of the way the synchronizer code works, it will also
185  // increment the semaphore and might wake up the consumer thread. This will
186  // possibly make it do a little useless work (go around a loop once), but
187  // won't really do any harm. Perhaps we should have implemented a version of
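188  // the Lock function that takes no arguments.
189  if (!waveform_synchronizer_.Lock(ThreadSynchronizer::kProducer)) {
190  KALDI_ERR << "Failure locking mutex: decoding aborted.";
191  }
192  int32 ans = input_waveform_.size();
193  waveform_synchronizer_.UnlockSuccess(ThreadSynchronizer::kProducer);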
194  return ans;
195 }

◆ ProcessLoglikes()

void ProcessLoglikes ( const CuVector< BaseFloat > &  log_inv_prior,
CuMatrixBase< BaseFloat > *  loglikes 
)
private

Definition at line 395 of file online-nnet2-decoding-threaded.cc.

References OnlineNnet2DecodingThreadedConfig::acoustic_scale, CuMatrixBase< Real >::AddVecToRows(), CuMatrixBase< Real >::ApplyFloor(), CuMatrixBase< Real >::ApplyLog(), SingleUtteranceNnet2DecoderThreaded::config_, CuMatrixBase< Real >::NumRows(), and CuMatrixBase< Real >::Scale().

Referenced by SingleUtteranceNnet2DecoderThreaded::RunNnetEvaluationInternal().

397  {
398  if (cu_loglikes->NumRows() != 0) {
399  cu_loglikes->ApplyFloor(1.0e-20);
400  cu_loglikes->ApplyLog();
401  // take the log-posteriors and turn them into pseudo-log-likelihoods by
402  // dividing by the pdf priors; then scale by the acoustic scale.
403  cu_loglikes->AddVecToRows(1.0, log_inv_prior);
404  cu_loglikes->Scale(config_.acoustic_scale);
405  }
406 }
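
The four steps in this listing (floor, log, add the log inverse prior to every row, scale) can be reproduced on plain std::vector data. The sketch below is a CPU-only illustration under that reading; `PosteriorsToScaledLoglikes` is a hypothetical name, and the real code operates on CuMatrixBase on whichever device Kaldi is using.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Turns rows of posteriors into scaled pseudo-log-likelihoods in place:
// floor at 1e-20, take the log, add log(1/prior) per column, then scale.
void PosteriorsToScaledLoglikes(const std::vector<float> &log_inv_prior,
                                float acoustic_scale,
                                std::vector<std::vector<float>> *rows) {
  assert(rows != nullptr);
  for (auto &row : *rows) {
    assert(row.size() == log_inv_prior.size());
    for (std::size_t j = 0; j < row.size(); ++j) {
      float floored = std::max(row[j], 1.0e-20f);  // avoids log(0)
      row[j] = (std::log(floored) + log_inv_prior[j]) * acoustic_scale;
    }
  }
}
```

Adding the log of the inverse prior is the same as dividing the posterior by the prior before taking the log, which is what the comment in the listing describes.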

◆ RunDecoderSearch()

void RunDecoderSearch ( SingleUtteranceNnet2DecoderThreaded *  me)
static private

Definition at line 371 of file online-nnet2-decoding-threaded.cc.

References SingleUtteranceNnet2DecoderThreaded::abort_, SingleUtteranceNnet2DecoderThreaded::AbortAllThreads(), KALDI_ERR, KALDI_WARN, and SingleUtteranceNnet2DecoderThreaded::RunDecoderSearchInternal().

Referenced by SingleUtteranceNnet2DecoderThreaded::SingleUtteranceNnet2DecoderThreaded().

372  {
373  try {
374  if (!me->RunDecoderSearchInternal() && !me->abort_)
375  KALDI_ERR << "Returned abnormally and abort was not called";
376  } catch(const std::exception &e) {
377  KALDI_WARN << "Caught exception: " << e.what();
378  // if an error happened in one thread, we need to make sure the other threads can exit too.
379  bool error = true;
380  me->AbortAllThreads(error);
381  }
382 }
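
The try/catch in this listing matters because an exception that escapes the top-level function of a std::thread calls std::terminate. The sketch below shows the same pattern in isolation: the thread entry point catches the exception and records it in shared flags so that other threads can notice and exit. `WorkerState`, `RunGuarded` and `FailingBody` are hypothetical names; the real class uses AbortAllThreads() rather than bare flags.

```cpp
#include <atomic>
#include <cassert>
#include <exception>
#include <functional>
#include <stdexcept>
#include <thread>

// Hypothetical shared state, standing in for the abort_/error_ members.
struct WorkerState {
  std::atomic<bool> abort{false};
  std::atomic<bool> error{false};
};

// Thread entry point: runs the body and converts any std::exception into
// flags instead of letting it escape the thread.
void RunGuarded(const std::function<void()> &body, WorkerState *state) {
  assert(state != nullptr);
  try {
    body();
  } catch (const std::exception &) {
    state->error = true;
    state->abort = true;  // lets the other threads know they should stop
  }
}

// Example body that fails, to exercise the catch branch.
void FailingBody() {
  throw std::runtime_error("simulated decoding failure");
}
```

A thread started as `std::thread(RunGuarded, std::function<void()>(FailingBody), &state)` then finishes normally, and the caller can inspect `state.error` after joining it.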

◆ RunDecoderSearchInternal()

bool RunDecoderSearchInternal ( )
private

Definition at line 615 of file online-nnet2-decoding-threaded.cc.

References OnlineSilenceWeighting::Active(), LatticeFasterDecoderTpl< FST, Token >::AdvanceDecoding(), OnlineSilenceWeighting::ComputeCurrentTraceback(), SingleUtteranceNnet2DecoderThreaded::config_, SingleUtteranceNnet2DecoderThreaded::decodable_, SingleUtteranceNnet2DecoderThreaded::decodable_synchronizer_, OnlineNnet2DecodingThreadedConfig::decode_batch_size, SingleUtteranceNnet2DecoderThreaded::decoder_, SingleUtteranceNnet2DecoderThreaded::decoder_mutex_, DecodableMatrixMappedOffset::IsLastFrame(), KALDI_ASSERT, ThreadSynchronizer::kConsumer, ThreadSynchronizer::Lock(), SingleUtteranceNnet2DecoderThreaded::num_frames_decoded_, LatticeFasterDecoderTpl< FST, Token >::NumFramesDecoded(), DecodableMatrixMappedOffset::NumFramesReady(), SingleUtteranceNnet2DecoderThreaded::silence_weighting_, SingleUtteranceNnet2DecoderThreaded::silence_weighting_mutex_, ThreadSynchronizer::UnlockFailure(), and ThreadSynchronizer::UnlockSuccess().

Referenced by SingleUtteranceNnet2DecoderThreaded::RunDecoderSearch().

{
  int32 num_frames_decoded = 0; // this is just a copy of decoder_->NumFramesDecoded();
  while (true) { // decode at most one frame each loop.
    if (!decodable_synchronizer_.Lock(ThreadSynchronizer::kConsumer))
      return false; // AbortAllThreads() called.
    if (decodable_.NumFramesReady() <= num_frames_decoded) {
      // no frames available to decode.
      KALDI_ASSERT(decodable_.NumFramesReady() == num_frames_decoded);
      if (decodable_.IsLastFrame(num_frames_decoded - 1)) {
        decodable_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer);
        return true; // exit from this thread; we're done.
      } else {
        // we were not able to advance the decoding due to no available
        // input. The UnlockFailure() call will ensure that the next call to
        // decodable_synchronizer_.Lock() will wait.
        if (!decodable_synchronizer_.UnlockFailure(ThreadSynchronizer::kConsumer))
          return false;
      }
    } else {
      // Decode at most config_.decode_batch_size frames (e.g. 1 or 2).
      decoder_mutex_.lock();
      decoder_.AdvanceDecoding(&decodable_, config_.decode_batch_size);
      num_frames_decoded = decoder_.NumFramesDecoded();
      if (silence_weighting_.Active()) {
        std::lock_guard<std::mutex> lock(silence_weighting_mutex_);
        // the next function does not trace back all the way; it's very fast.
        silence_weighting_.ComputeCurrentTraceback(decoder_);
      }
      decoder_mutex_.unlock();
      num_frames_decoded_ = num_frames_decoded;
      if (!decodable_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer))
        return false;
    }
  }
}
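
Each pass of the search loop makes one of three decisions: finish, wait for more input, or advance by at most one batch. The sketch below isolates that decision in a single-threaded form. It is an illustration under stated assumptions rather than Kaldi code: DecodableState, StepResult and DecodeStep are hypothetical names, the synchronizer and mutex handling is omitted, and the input_finished flag approximates IsLastFrame(num_frames_decoded - 1) for the case where no frames are left.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical outcome of one loop iteration.
enum class StepResult { kFinished, kMustWait, kAdvanced };

// Hypothetical view of the decodable object as seen by the search thread.
struct DecodableState {
  int frames_ready;     // frames of log-likelihoods available so far
  bool input_finished;  // true once no more frames will arrive
};

// One iteration of the consumer loop. Advances *num_frames_decoded by at
// most batch_size frames when frames are available.
StepResult DecodeStep(const DecodableState &d, int batch_size,
                      int *num_frames_decoded) {
  assert(batch_size > 0);
  assert(d.frames_ready >= *num_frames_decoded);
  if (d.frames_ready == *num_frames_decoded) {
    // No frames available: either everything is decoded, or the producer
    // has to supply more before this thread can make progress.
    return d.input_finished ? StepResult::kFinished : StepResult::kMustWait;
  }
  *num_frames_decoded +=
      std::min(batch_size, d.frames_ready - *num_frames_decoded);
  return StepResult::kAdvanced;
}
```

With three frames ready and a batch size of 2, successive calls advance to 2 and then 3 frames, then report kMustWait until input_finished is set, after which they report kFinished.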

◆ RunNnetEvaluation()

void RunNnetEvaluation ( SingleUtteranceNnet2DecoderThreaded * me )
static private

Definition at line 357 of file online-nnet2-decoding-threaded.cc.

References SingleUtteranceNnet2DecoderThreaded::abort_, SingleUtteranceNnet2DecoderThreaded::AbortAllThreads(), KALDI_ERR, KALDI_WARN, and SingleUtteranceNnet2DecoderThreaded::RunNnetEvaluationInternal().

Referenced by SingleUtteranceNnet2DecoderThreaded::SingleUtteranceNnet2DecoderThreaded().

{
  try {
    if (!me->RunNnetEvaluationInternal() && !me->abort_)
      KALDI_ERR << "Returned abnormally and abort was not called";
  } catch(const std::exception &e) {
    KALDI_WARN << "Caught exception: " << e.what();
    // if an error happened in one thread, we need to make sure the other
    // threads can exit too.
    bool error = true;
    me->AbortAllThreads(error);
  }
}

◆ RunNnetEvaluationInternal()

bool RunNnetEvaluationInternal ( )
private

Definition at line 470 of file online-nnet2-decoding-threaded.cc.

References DecodableMatrixMappedOffset::AcceptLoglikes(), OnlineSilenceWeighting::Active(), SingleUtteranceNnet2DecoderThreaded::am_nnet_, CuVectorBase< Real >::ApplyFloor(), SingleUtteranceNnet2DecoderThreaded::config_, SingleUtteranceNnet2DecoderThreaded::decodable_, SingleUtteranceNnet2DecoderThreaded::decodable_synchronizer_, OnlineNnet2FeaturePipeline::Dim(), SingleUtteranceNnet2DecoderThreaded::feature_pipeline_, SingleUtteranceNnet2DecoderThreaded::feature_pipeline_mutex_, SingleUtteranceNnet2DecoderThreaded::FeatureComputation(), DecodableMatrixMappedOffset::FirstAvailableFrame(), OnlineSilenceWeighting::GetDeltaWeights(), OnlineNnet2FeaturePipeline::GetFrame(), AmNnet::GetNnet(), rnnlm::i, DecodableMatrixMappedOffset::InputIsFinished(), OnlineNnet2FeaturePipeline::IsLastFrame(), OnlineNnet2FeaturePipeline::IvectorFeature(), KALDI_ASSERT, ThreadSynchronizer::kProducer, ThreadSynchronizer::Lock(), OnlineNnet2DecodingThreadedConfig::max_loglikes_copy, OnlineNnet2DecodingThreadedConfig::nnet_batch_size, SingleUtteranceNnet2DecoderThreaded::num_frames_decoded_, OnlineNnet2FeaturePipeline::NumFramesReady(), OnlineIvectorFeature::NumFramesReady(), MatrixBase< Real >::NumRows(), CuMatrixBase< Real >::NumRows(), AmNnet::Priors(), SingleUtteranceNnet2DecoderThreaded::ProcessLoglikes(), Matrix< Real >::Resize(), SingleUtteranceNnet2DecoderThreaded::silence_weighting_, SingleUtteranceNnet2DecoderThreaded::silence_weighting_mutex_, Matrix< Real >::Swap(), CuMatrix< Real >::Swap(), ThreadSynchronizer::UnlockFailure(), ThreadSynchronizer::UnlockSuccess(), and OnlineIvectorFeature::UpdateFrameWeights().

Referenced by SingleUtteranceNnet2DecoderThreaded::RunNnetEvaluation().

470  {
471  // if any of the Lock/Unlock functions return false, it's because AbortAllThreads()
472  // was called.
473 
474  // This object is responsible for keeping track of the context, and avoiding
475  // re-computing things we've already computed.
476  bool pad_input = true;
477  nnet2::NnetOnlineComputer computer(am_nnet_.GetNnet(), pad_input);
478 
479  // we declare the following as CuVector just to enable GPU support, but
480  // we expect this code to be run on CPU in the normal case.
481  CuVector<BaseFloat> log_inv_prior(am_nnet_.Priors());
482  log_inv_prior.ApplyFloor(1.0e-20); // should have no effect.
483  log_inv_prior.ApplyLog();
484  log_inv_prior.Scale(-1.0);
485 
486  // we'll have num_frames_consumed >= num_frames_output; num_frames_consumed is
487  // the number of feature frames consumed by the nnet computation,
488  // num_frames_output is the number of frames of loglikes the nnet computation
489  // has produced, which may be less than num_frames_consumed due to the
490  // right-context of the network.
491  int32 num_frames_consumed = 0, num_frames_output = 0;
492 
493  while (true) {
494  bool last_time = false;
495 
496  /****** Begin locking of feature pipeline mutex. ******/
498  if (!FeatureComputation(num_frames_consumed)) { // error
499  feature_pipeline_mutex_.unlock();
500  return false;
501  }
502  // take care of silence weighting.
503  if (silence_weighting_.Active() &&
504  feature_pipeline_.IvectorFeature() != NULL) {
506  std::vector<std::pair<int32, BaseFloat> > delta_weights;
509  &delta_weights);
510  silence_weighting_mutex_.unlock();
512  }
513 
514  int32 num_frames_ready = feature_pipeline_.NumFramesReady(),
515  num_frames_usable = num_frames_ready - num_frames_consumed;
516  bool features_done = feature_pipeline_.IsLastFrame(num_frames_ready - 1);
517 
518  int32 num_frames_evaluate = std::min<int32>(num_frames_usable,
520 
521  Matrix<BaseFloat> feats;
522  if (num_frames_evaluate > 0) {
523  // we have something to do...
524  feats.Resize(num_frames_evaluate, feature_pipeline_.Dim());
525  for (int32 i = 0; i < num_frames_evaluate; i++) {
526  int32 t = num_frames_consumed + i;
527  SubVector<BaseFloat> feat(feats, i);
528  feature_pipeline_.GetFrame(t, &feat);
529  }
530  }
531  /****** End locking of feature pipeline mutex. ******/
532  feature_pipeline_mutex_.unlock();
533 
534  CuMatrix<BaseFloat> cu_loglikes;
535 
536  if (feats.NumRows() == 0) {
537  if (features_done) {
538  // flush out the last few frames. Note: this is the only place from
539  // which we check feature_buffer_finished_, and we'll exit the loop, so
540  // if we reach here it must be the first time it was true.
541  last_time = true;
542  computer.Flush(&cu_loglikes);
543  ProcessLoglikes(log_inv_prior, &cu_loglikes);
544  }
545  } else {
546  CuMatrix<BaseFloat> cu_feats;
547  cu_feats.Swap(&feats); // If we don't have a GPU (and not having a GPU is
548  // the normal expected use-case for this code),
549  // this would be a lightweight operation, swapping
550  // pointers.
551 
552  computer.Compute(cu_feats, &cu_loglikes);
553  num_frames_consumed += cu_feats.NumRows();
554  ProcessLoglikes(log_inv_prior, &cu_loglikes);
555  }
556 
557  Matrix<BaseFloat> loglikes;
558  loglikes.Swap(&cu_loglikes); // If we don't have a GPU (and not having a
559  // GPU is the normal expected use-case for
560  // this code), this would be a lightweight
561  // operation, swapping pointers.
562 
563 
564  // OK, at this point we may have some newly created log-likes and we want to
565  // give them to the decoding thread.
566 
567  int32 num_loglike_frames = loglikes.NumRows();
568 
569  if (num_loglike_frames != 0) { // if we need to output some loglikes...
570  while (true) {
571  // we may have to grab and release the decodable mutex
572  // a few times before it's ready to accept the loglikes.
574  return false;
575  int32 num_frames_decoded = num_frames_decoded_;
576  // we can't have output fewer frames than were decoded.
577  KALDI_ASSERT(num_frames_output >= num_frames_decoded);
578  if (num_frames_output - num_frames_decoded <= config_.max_loglikes_copy) {
579  // If we would have to copy fewer than config_.max_loglikes_copy
580  // previously output log-likelihoods inside the decodable object, then
581  // we go ahead and copy them to that object.
582  int32 frames_to_discard = num_frames_decoded_ -
584  KALDI_ASSERT(frames_to_discard >= 0);
585  num_frames_output += num_loglike_frames;
586  decodable_.AcceptLoglikes(&loglikes, frames_to_discard);
588  return false;
589  break; // break from the innermost while loop.
590  } else {
591  // There are too many frames already available to the decoder, that it
592  // hasn't processed yet, and we don't want them to have to be copied
593  // inside AcceptLoglikes(), so we wait for a bit.
594  // we want the next call to Lock to block until the decoder has
595  // processed more frames.
597  return false;
598  }
599  }
600  }
601  if (last_time) {
602  // Inform the decodable object that there will be no more input.
604  return false;
607  return false;
608  KALDI_ASSERT(num_frames_consumed == num_frames_output);
609  return true;
610  }
611  }
612 }
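
Near the top of the listing above, the model priors are turned into a log-inverse-prior vector in three steps: floor at 1.0e-20, take the log, and scale by -1. A minimal sketch of that arithmetic using std::vector in place of CuVector follows; the function name LogInvPrior is hypothetical and the sketch does not cover how ProcessLoglikes() later uses the result.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Computes -log(max(p, floor)) for each prior p, mirroring the
// ApplyFloor / ApplyLog / Scale(-1.0) sequence in the listing.
std::vector<double> LogInvPrior(const std::vector<double> &priors,
                                double floor = 1.0e-20) {
  assert(floor > 0.0);
  std::vector<double> out(priors.size());
  for (std::size_t i = 0; i < priors.size(); i++)
    out[i] = -std::log(std::max(priors[i], floor));
  return out;
}
```

For a prior of 0.5 the result is log(2). The floor only matters for priors at or near zero, which matches the "should have no effect" comment in the listing.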

◆ TerminateDecoding()

void TerminateDecoding ( )

You can call this if you don't want the decoding to proceed further with this utterance.

No further processing is done, but the lattice from the decoding already completed remains usable. Note: the decoding thread may still decode up to decode_batch_size (default: 2) more frames of data before it exits. Call Wait() after calling this if you want to wait for that.

Definition at line 215 of file online-nnet2-decoding-threaded.cc.

References SingleUtteranceNnet2DecoderThreaded::AbortAllThreads().

{
  bool error = false;
  AbortAllThreads(error);
}
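
The note about decode_batch_size follows from the search thread checking for an abort only between batches, so a batch that is already in progress runs to completion. The toy single-threaded simulation below illustrates this under that assumption. The function FramesDecodedAfterAbort and its parameters are hypothetical and no Kaldi types are involved.

```cpp
#include <algorithm>
#include <cassert>

// Simulates decoding `total` frames in batches of `batch` frames, with an
// abort request arriving while frame `abort_at_frame` is being decoded.
// The abort flag is only examined between batches. Returns the number of
// frames decoded when the loop exits.
int FramesDecodedAfterAbort(int total, int batch, int abort_at_frame) {
  assert(total >= 0 && batch > 0);
  bool abort = false;
  int decoded = 0;
  while (!abort && decoded < total) {
    int n = std::min(batch, total - decoded);
    for (int i = 0; i < n; i++) {
      if (decoded == abort_at_frame) abort = true;  // request arrives mid-batch
      decoded++;
    }
  }
  return decoded;
}
```

With a batch size of 2 and an abort arriving during frame 4, the loop stops after 6 frames, so the overshoot past the abort point never exceeds one batch.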

◆ Wait()

void Wait ( )

This call blocks until all the data has been decoded. It must only be called after either InputFinished() or TerminateDecoding() has been called; calling it in any other state is an error.

Definition at line 220 of file online-nnet2-decoding-threaded.cc.

References SingleUtteranceNnet2DecoderThreaded::abort_, SingleUtteranceNnet2DecoderThreaded::input_finished_, KALDI_ERR, and SingleUtteranceNnet2DecoderThreaded::WaitForAllThreads().

{
  if (!input_finished_ && !abort_) {
    KALDI_ERR << "You cannot call Wait() before calling either InputFinished() "
              << "or TerminateDecoding().";
  }
  WaitForAllThreads();
}
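
The precondition on Wait() can be demonstrated with a small mock. This is a sketch under stated assumptions: MockDecoder is a hypothetical class that only tracks the two flags, std::logic_error stands in for KALDI_ERR, and no threads are actually joined.

```cpp
#include <stdexcept>

// Hypothetical mock that reproduces only the call-order rule of Wait().
class MockDecoder {
 public:
  void InputFinished() { input_finished_ = true; }
  void TerminateDecoding() { abort_ = true; }

  // Mirrors the precondition check in Wait(); throws instead of KALDI_ERR.
  void Wait() {
    if (!input_finished_ && !abort_)
      throw std::logic_error(
          "You cannot call Wait() before calling either InputFinished() "
          "or TerminateDecoding().");
    waited_ = true;  // a real implementation would join the worker threads
  }

  bool waited() const { return waited_; }

 private:
  bool input_finished_ = false;
  bool abort_ = false;
  bool waited_ = false;
};

// Returns true if Wait() is rejected on a freshly constructed decoder.
bool WaitRejectedWithoutSignal() {
  MockDecoder d;
  try {
    d.Wait();
  } catch (const std::logic_error &) {
    return true;
  }
  return false;
}
```

Calling either InputFinished() or TerminateDecoding() first makes Wait() succeed in this mock; calling neither makes it throw.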

◆ WaitForAllThreads()

void WaitForAllThreads ( )
private

Definition at line 385 of file online-nnet2-decoding-threaded.cc.

References SingleUtteranceNnet2DecoderThreaded::error_, rnnlm::i, KALDI_ERR, and SingleUtteranceNnet2DecoderThreaded::threads_.

Referenced by SingleUtteranceNnet2DecoderThreaded::Wait(), and SingleUtteranceNnet2DecoderThreaded::~SingleUtteranceNnet2DecoderThreaded().

{
  for (int32 i = 0; i < 2; i++) { // there are 2 spawned threads.
    if (threads_[i].joinable())
      threads_[i].join();
  }
  if (error_)
    KALDI_ERR << "Error encountered during decoding. See above.";
}
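
Because this function is referenced by both Wait() and the destructor, the joinable() check is what makes a second call harmless. The sketch below shows the same join-then-report shape with two std::thread objects. TwoThreadRunner and RunPair are hypothetical names, std::runtime_error stands in for KALDI_ERR, and the worker bodies are placeholders.

```cpp
#include <atomic>
#include <stdexcept>
#include <thread>

// Hypothetical holder for two worker threads and a shared error flag.
struct TwoThreadRunner {
  std::thread threads[2];
  std::atomic<bool> error{false};

  // Joins whichever threads are still joinable, then reports any error.
  void WaitForAllThreads() {
    for (int i = 0; i < 2; i++) {
      if (threads[i].joinable())
        threads[i].join();
    }
    if (error)
      throw std::runtime_error("Error encountered during decoding.");
  }

  // Never leave a joinable std::thread behind at destruction.
  ~TwoThreadRunner() {
    for (auto &t : threads)
      if (t.joinable()) t.join();
  }
};

// Spawns two placeholder workers; returns true if waiting succeeds.
bool RunPair(bool second_fails) {
  TwoThreadRunner r;
  r.threads[0] = std::thread([] {});
  r.threads[1] = std::thread([&r, second_fails] {
    if (second_fails) r.error = true;
  });
  try {
    r.WaitForAllThreads();
  } catch (const std::runtime_error &) {
    return false;
  }
  // A second call is a no-op: the threads are no longer joinable.
  r.WaitForAllThreads();
  return true;
}
```

Note that the joins happen before the error is reported, so both threads have exited by the time the exception propagates.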

Member Data Documentation

◆ abort_

◆ am_nnet_

◆ config_

◆ decodable_

◆ decodable_synchronizer_

◆ decoder_

◆ decoder_mutex_

◆ error_

◆ feature_pipeline_

◆ feature_pipeline_mutex_

◆ input_finished_

◆ input_waveform_

◆ num_frames_decoded_

◆ num_samples_discarded_

◆ num_samples_received_

◆ processed_waveform_

◆ sampling_rate_

◆ silence_weighting_

◆ silence_weighting_mutex_

◆ threads_

◆ tmodel_

◆ waveform_synchronizer_


The documentation for this class was generated from the following files: