OnlineSilenceWeighting Class Reference

#include <online-ivector-feature.h>

Collaboration diagram for OnlineSilenceWeighting:

Classes

struct  FrameInfo
 

Public Member Functions

 OnlineSilenceWeighting (const TransitionModel &trans_model, const OnlineSilenceWeightingConfig &config, int32 frame_subsampling_factor=1)
 
bool Active () const
 
template<typename FST >
void ComputeCurrentTraceback (const LatticeFasterOnlineDecoderTpl< FST > &decoder)
 
template<typename FST >
void ComputeCurrentTraceback (const LatticeIncrementalOnlineDecoderTpl< FST > &decoder)
 
void GetDeltaWeights (int32 num_frames_ready, int32 first_decoder_frame, std::vector< std::pair< int32, BaseFloat > > *delta_weights)
 
void GetDeltaWeights (int32 num_frames_ready, std::vector< std::pair< int32, BaseFloat > > *delta_weights)
 

Private Attributes

const TransitionModel & trans_model_
 
const OnlineSilenceWeightingConfig & config_
 
int32 frame_subsampling_factor_
 
unordered_set< int32 > silence_phones_
 
std::vector< FrameInfo > frame_info_
 
int32 num_frames_output_and_correct_
 

Detailed Description

Definition at line 465 of file online-ivector-feature.h.

Constructor & Destructor Documentation

◆ OnlineSilenceWeighting()

OnlineSilenceWeighting ( const TransitionModel & trans_model,
const OnlineSilenceWeightingConfig & config,
int32  frame_subsampling_factor = 1 
)

Definition at line 465 of file online-ivector-feature.cc.

References OnlineSilenceWeighting::frame_subsampling_factor_, rnnlm::i, KALDI_ASSERT, OnlineSilenceWeighting::silence_phones_, OnlineSilenceWeightingConfig::silence_phones_str, and kaldi::SplitStringToIntegers().

468  :
469  trans_model_(trans_model), config_(config),
470  frame_subsampling_factor_(frame_subsampling_factor),
471  num_frames_output_and_correct_(0) {
472  KALDI_ASSERT(frame_subsampling_factor_ >= 1);
473  std::vector<int32> silence_phones;
474  SplitStringToIntegers(config.silence_phones_str, ":,", false,
475  &silence_phones);
476  for (size_t i = 0; i < silence_phones.size(); i++)
477  silence_phones_.insert(silence_phones[i]);
478 }
bool SplitStringToIntegers(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< I > *out)
Split a string (e.g. 1:2:3) into a vector of integers.
Definition: text-utils.h:68
const OnlineSilenceWeightingConfig & config_
unordered_set< int32 > silence_phones_
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
const TransitionModel & trans_model_

Member Function Documentation

◆ Active()

◆ ComputeCurrentTraceback() [1/2]

void ComputeCurrentTraceback ( const LatticeFasterOnlineDecoderTpl< FST > &  decoder)

Definition at line 482 of file online-ivector-feature.cc.

References LatticeFasterOnlineDecoderTpl< FST >::BestPathEnd(), LatticeFasterOnlineDecoderTpl< FST >::BestPathIterator::frame, OnlineSilenceWeighting::frame_info_, KALDI_ASSERT, KALDI_ERR, OnlineSilenceWeighting::num_frames_output_and_correct_, LatticeFasterDecoderTpl< FST, decoder::BackpointerToken >::NumFramesDecoded(), LatticeFasterOnlineDecoderTpl< FST >::BestPathIterator::tok, and LatticeFasterOnlineDecoderTpl< FST >::TraceBackBestPath().

Referenced by main(), and SingleUtteranceNnet2DecoderThreaded::RunDecoderSearchInternal().

483  {
484  int32 num_frames_decoded = decoder.NumFramesDecoded(),
485  num_frames_prev = frame_info_.size();
486  // note, num_frames_prev is not the number of frames previously decoded,
487  // it's the generally-larger number of frames that we were requested to
488  // provide weights for.
489  if (num_frames_prev < num_frames_decoded)
490  frame_info_.resize(num_frames_decoded);
491  if (num_frames_prev > num_frames_decoded &&
492  frame_info_[num_frames_decoded].transition_id != -1)
493  KALDI_ERR << "Number of frames decoded decreased"; // Likely bug
494 
495  if (num_frames_decoded == 0)
496  return;
497  int32 frame = num_frames_decoded - 1;
498  bool use_final_probs = false;
499  typename LatticeFasterOnlineDecoderTpl<FST>::BestPathIterator iter =
500  decoder.BestPathEnd(use_final_probs, NULL);
501  while (frame >= 0) {
502  LatticeArc arc;
503  arc.ilabel = 0;
504  while (arc.ilabel == 0) // the while loop skips over input-epsilons
505  iter = decoder.TraceBackBestPath(iter, &arc);
506  // note, the iter.frame values are slightly unintuitively defined,
507  // they are one less than you might expect.
508  KALDI_ASSERT(iter.frame == frame - 1);
509 
510  if (frame_info_[frame].token == iter.tok) {
511  // we know that the traceback from this point back will be identical, so
512  // no point tracing back further. Note: we are comparing memory addresses
513  // of tokens of the decoder; this guarantees it's the same exact token
514  // because tokens, once allocated on a frame, are only deleted, never
515  // reallocated for that frame.
516  break;
517  }
518 
519  if (num_frames_output_and_correct_ > frame)
520  num_frames_output_and_correct_ = frame;
521 
522  frame_info_[frame].token = iter.tok;
523  frame_info_[frame].transition_id = arc.ilabel;
524  frame--;
525  // leave frame_info_.current_weight at zero for now (as set in the
526  // constructor), reflecting that we haven't already output a weight for that
527  // frame.
528  }
529 }
fst::ArcTpl< LatticeWeight > LatticeArc
Definition: kaldi-lattice.h:40
kaldi::int32 int32
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
std::vector< FrameInfo > frame_info_

◆ ComputeCurrentTraceback() [2/2]

void ComputeCurrentTraceback ( const LatticeIncrementalOnlineDecoderTpl< FST > &  decoder)

Definition at line 532 of file online-ivector-feature.cc.

References LatticeIncrementalOnlineDecoderTpl< FST >::BestPathEnd(), LatticeIncrementalOnlineDecoderTpl< FST >::BestPathIterator::frame, OnlineSilenceWeighting::frame_info_, KALDI_ASSERT, KALDI_ERR, OnlineSilenceWeighting::num_frames_output_and_correct_, LatticeIncrementalDecoderTpl< FST, decoder::BackpointerToken >::NumFramesDecoded(), LatticeIncrementalOnlineDecoderTpl< FST >::BestPathIterator::tok, and LatticeIncrementalOnlineDecoderTpl< FST >::TraceBackBestPath().

533  {
534  int32 num_frames_decoded = decoder.NumFramesDecoded(),
535  num_frames_prev = frame_info_.size();
536  // note, num_frames_prev is not the number of frames previously decoded,
537  // it's the generally-larger number of frames that we were requested to
538  // provide weights for.
539  if (num_frames_prev < num_frames_decoded)
540  frame_info_.resize(num_frames_decoded);
541  if (num_frames_prev > num_frames_decoded &&
542  frame_info_[num_frames_decoded].transition_id != -1)
543  KALDI_ERR << "Number of frames decoded decreased"; // Likely bug
544 
545  if (num_frames_decoded == 0)
546  return;
547  int32 frame = num_frames_decoded - 1;
548  bool use_final_probs = false;
549  typename LatticeIncrementalOnlineDecoderTpl<FST>::BestPathIterator iter =
550  decoder.BestPathEnd(use_final_probs, NULL);
551  while (frame >= 0) {
552  LatticeArc arc;
553  arc.ilabel = 0;
554  while (arc.ilabel == 0) // the while loop skips over input-epsilons
555  iter = decoder.TraceBackBestPath(iter, &arc);
556  // note, the iter.frame values are slightly unintuitively defined,
557  // they are one less than you might expect.
558  KALDI_ASSERT(iter.frame == frame - 1);
559 
560  if (frame_info_[frame].token == iter.tok) {
561  // we know that the traceback from this point back will be identical, so
562  // no point tracing back further. Note: we are comparing memory addresses
563  // of tokens of the decoder; this guarantees it's the same exact token,
564  // because tokens, once allocated on a frame, are only deleted, never
565  // reallocated for that frame.
566  break;
567  }
568 
569  if (num_frames_output_and_correct_ > frame)
570  num_frames_output_and_correct_ = frame;
571 
572  frame_info_[frame].token = iter.tok;
573  frame_info_[frame].transition_id = arc.ilabel;
574  frame--;
575  // leave frame_info_.current_weight at zero for now (as set in the
576  // constructor), reflecting that we haven't already output a weight for that
577  // frame.
578  }
579 }
fst::ArcTpl< LatticeWeight > LatticeArc
Definition: kaldi-lattice.h:40
kaldi::int32 int32
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
std::vector< FrameInfo > frame_info_

◆ GetDeltaWeights() [1/2]

void GetDeltaWeights ( int32  num_frames_ready,
int32  first_decoder_frame,
std::vector< std::pair< int32, BaseFloat > > *  delta_weights 
)

Definition at line 597 of file online-ivector-feature.cc.

References OnlineSilenceWeighting::config_, OnlineSilenceWeighting::frame_info_, OnlineSilenceWeighting::frame_subsampling_factor_, rnnlm::i, KALDI_ASSERT, KALDI_VLOG, OnlineSilenceWeightingConfig::max_state_duration, OnlineSilenceWeighting::silence_phones_, OnlineSilenceWeightingConfig::silence_weight, OnlineSilenceWeighting::trans_model_, and TransitionModel::TransitionIdToPhone().

Referenced by main(), and SingleUtteranceNnet2DecoderThreaded::RunNnetEvaluationInternal().

599  {
600  // num_frames_ready is at the feature frame-rate, most of the code
601  // in this function is at the decoder frame-rate.
602  // round up, so we are sure to get weights for at least the frame
603  // 'num_frames_ready - 1', and maybe one or two frames afterward.
604  KALDI_ASSERT(num_frames_ready > first_decoder_frame || num_frames_ready == 0);
605  int32 fs = frame_subsampling_factor_,
606  num_decoder_frames_ready = (num_frames_ready - first_decoder_frame + fs - 1) / fs;
607 
608  const int32 max_state_duration = config_.max_state_duration;
609  const BaseFloat silence_weight = config_.silence_weight;
610 
611  delta_weights->clear();
612 
613  int32 prev_num_frames_processed = frame_info_.size();
614  if (frame_info_.size() < static_cast<size_t>(num_decoder_frames_ready))
615  frame_info_.resize(num_decoder_frames_ready);
616 
617  // Don't go further backward into the past then 100 frames before the most
618  // recent frame previously than 100 frames when modifying the traceback.
619  // C.f. the value 200 in template
620  // OnlineGenericBaseFeature<C>::OnlineGenericBaseFeature in online-feature.cc,
621  // which needs to be more than this value of 100 plus the amount of context
622  // that LDA might use plus the chunk size we're likely to decode in one time.
623  // The user can always increase the value of --max-feature-vectors in case one
624  // of these conditions is broken. Search for ONLINE_IVECTOR_LIMIT in
625  // online-feature.cc
626  int32 begin_frame = std::max<int32>(0, prev_num_frames_processed - 100),
627  frames_out = static_cast<int32>(frame_info_.size()) - begin_frame;
628  // frames_out is the number of frames we will output.
629  KALDI_ASSERT(frames_out >= 0);
630  std::vector<BaseFloat> frame_weight(frames_out, 1.0);
631  // we will set frame_weight to the value silence_weight for silence frames and
632  // for transition-ids that repeat with duration > max_state_duration. Frames
633  // newer than the most recent traceback will get a weight equal to the weight
634  // for the most recent frame in the traceback; or the silence weight, if there
635  // is no traceback at all available yet.
636 
637  // First treat some special cases.
638  if (frames_out == 0) // Nothing to output.
639  return;
640  if (frame_info_[begin_frame].transition_id == -1) {
641  // We do not have any traceback at all within the frames we are to output...
642  // find the most recent weight that we output and apply the same weight to
643  // all the new output; or output the silence weight, if nothing was output.
644  BaseFloat weight = (begin_frame == 0 ? silence_weight :
645  frame_info_[begin_frame - 1].current_weight);
646  for (int32 offset = 0; offset < frames_out; offset++)
647  frame_weight[offset] = weight;
648  } else {
649  int32 current_run_start_offset = 0;
650  for (int32 offset = 0; offset < frames_out; offset++) {
651  int32 frame = begin_frame + offset;
652  int32 transition_id = frame_info_[frame].transition_id;
653  if (transition_id == -1) {
654  // this frame does not yet have a decoder traceback, so just
655  // duplicate the silence/non-silence status of the most recent
656  // frame we have a traceback for (probably a reasonable guess).
657  frame_weight[offset] = frame_weight[offset - 1];
658  } else {
659  int32 phone = trans_model_.TransitionIdToPhone(transition_id);
660  bool is_silence = (silence_phones_.count(phone) != 0);
661  if (is_silence)
662  frame_weight[offset] = silence_weight;
663  // now deal with max-duration issues.
664  if (max_state_duration > 0 &&
665  (offset + 1 == frames_out ||
666  transition_id != frame_info_[frame + 1].transition_id)) {
667  // If this is the last frame of a run...
668  int32 run_length = offset - current_run_start_offset + 1;
669  if (run_length >= max_state_duration) {
670  // treat runs of the same transition-id longer than the max, as
671  // silence, even if they were not silence.
672  for (int32 offset2 = current_run_start_offset;
673  offset2 <= offset; offset2++)
674  frame_weight[offset2] = silence_weight;
675  }
676  if (offset + 1 < frames_out)
677  current_run_start_offset = offset + 1;
678  }
679  }
680  }
681  }
682  // Now commit the stats...
683  for (int32 offset = 0; offset < frames_out; offset++) {
684  int32 frame = begin_frame + offset;
685  BaseFloat old_weight = frame_info_[frame].current_weight,
686  new_weight = frame_weight[offset],
687  weight_diff = new_weight - old_weight;
688  frame_info_[frame].current_weight = new_weight;
689  // Even if the delta-weight is zero for the last frame, we provide it,
690  // because the identity of the most recent frame with a weight is used in
691  // some debugging/checking code.
692  if (weight_diff != 0.0 || offset + 1 == frames_out) {
693  KALDI_VLOG(6) << "Weight for frame " << frame << " changing from "
694  << old_weight << " to " << new_weight;
695  for(int32 i = 0; i < frame_subsampling_factor_; i++) {
696  int32 input_frame = first_decoder_frame + (frame * frame_subsampling_factor_) + i;
697  delta_weights->push_back(std::make_pair(input_frame, weight_diff));
698  }
699  }
700  }
701 }
kaldi::int32 int32
float BaseFloat
Definition: kaldi-types.h:29
const OnlineSilenceWeightingConfig & config_
unordered_set< int32 > silence_phones_
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
const TransitionModel & trans_model_
int32 TransitionIdToPhone(int32 trans_id) const
std::vector< FrameInfo > frame_info_

◆ GetDeltaWeights() [2/2]

void GetDeltaWeights ( int32  num_frames_ready,
std::vector< std::pair< int32, BaseFloat > > *  delta_weights 
)
inline

Definition at line 519 of file online-ivector-feature.h.

521  {
522  GetDeltaWeights(num_frames_ready, 0, delta_weights);
523  }
void GetDeltaWeights(int32 num_frames_ready, int32 first_decoder_frame, std::vector< std::pair< int32, BaseFloat > > *delta_weights)

Member Data Documentation

◆ config_

const OnlineSilenceWeightingConfig& config_
private

Definition at line 527 of file online-ivector-feature.h.

Referenced by OnlineSilenceWeighting::GetDeltaWeights().

◆ frame_info_

std::vector<FrameInfo> frame_info_
private

◆ frame_subsampling_factor_

int32 frame_subsampling_factor_
private

◆ num_frames_output_and_correct_

int32 num_frames_output_and_correct_
private

◆ silence_phones_

unordered_set<int32> silence_phones_
private

◆ trans_model_

const TransitionModel& trans_model_
private

Definition at line 526 of file online-ivector-feature.h.

Referenced by OnlineSilenceWeighting::GetDeltaWeights().


The documentation for this class was generated from the following files: