    producer_waiting_(false),
    consumer_waiting_(false),
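The two flags above belong to ThreadSynchronizer, which coordinates the producer thread (the one supplying waveform or log-likelihoods) with the consumer thread. A minimal illustrative analogue of the "record that you are waiting before you block" idea follows; it is an assumption for explanation only, since Kaldi's actual ThreadSynchronizer is built on a pair of semaphores (producer_semaphore_, consumer_semaphore_).

#include <condition_variable>
#include <mutex>

// Illustrative only, not Kaldi's implementation: the consumer announces that it
// is about to block, so the producer knows a wake-up is needed.
struct MiniSynchronizer {
  std::mutex mu;
  std::condition_variable cv;
  bool consumer_waiting = false;
  bool data_available = false;

  void ConsumerWaitForData() {               // consumer side
    std::unique_lock<std::mutex> lock(mu);
    consumer_waiting = true;
    cv.wait(lock, [this] { return data_available; });
    consumer_waiting = false;
    data_available = false;
  }
  void ProducerAnnounceData() {              // producer side
    std::lock_guard<std::mutex> lock(mu);
    data_available = true;
    if (consumer_waiting) cv.notify_one();   // only wake a consumer that is asleep
  }
};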
    const fst::Fst<fst::StdArc> &fst,
    config_(config), am_nnet_(am_nnet), tmodel_(tmodel), sampling_rate_(0.0),
    num_samples_received_(0), input_finished_(false),
    feature_pipeline_(feature_info),
    num_samples_discarded_(0),
    silence_weighting_(tmodel, feature_info.silence_weighting_config),
    num_frames_decoded_(0), decoder_(fst, config_.decoder_opts),
    abort_(false), error_(false) {
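The initializer list above is from the SingleUtteranceNnet2DecoderThreaded constructor. A usage sketch based on the class's public interface follows; the audio source GetNextChunk and the 16 kHz sample rate are hypothetical, not part of Kaldi.

#include "online2/online-nnet2-decoding-threaded.h"

using namespace kaldi;

bool GetNextChunk(Vector<BaseFloat> *chunk);   // hypothetical audio source, not part of Kaldi

// Hedged usage sketch: feed audio in chunks, signal end of input, wait for the
// background threads to finish, then read the result.
void DecodeOneUtterance(SingleUtteranceNnet2DecoderThreaded *decoder) {
  Vector<BaseFloat> chunk;
  while (GetNextChunk(&chunk))
    decoder->AcceptWaveform(16000.0, chunk);          // sample rate assumed to be 16 kHz
  decoder->InputFinished();                           // no more waveform will arrive
  decoder->Wait();                                    // block until all data is decoded
  decoder->FinalizeDecoding();                        // only legal after Wait()
  CompactLattice clat;
  BaseFloat final_relative_cost;
  decoder->GetLattice(true, &clat, &final_relative_cost);  // end_of_utterance == true
}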
  if (wave_part.Dim() == 0) return;
  KALDI_ERR << "Failure locking mutex: decoding aborted.";
  KALDI_ERR << "Failure locking mutex: decoding aborted.";
  KALDI_ERR << "Failure locking mutex: decoding aborted.";
  KALDI_ERR << "You cannot call Wait() before calling either InputFinished() "
            << "or TerminateDecoding().";
  KALDI_ERR << "It is an error to call FinalizeDecoding before Wait().";
  KALDI_ERR << "It is an error to call GetRemainingWaveform before Wait().";
  int64 num_samples_stored = 0;  // total samples still held in the stored pieces.
  std::vector< Vector<BaseFloat>* > all_pieces;
  std::deque< Vector<BaseFloat>* >::const_iterator iter;
  for (iter = processed_waveform_.begin(); iter != processed_waveform_.end(); ++iter) {
    num_samples_stored += (*iter)->Dim();
    all_pieces.push_back(*iter);
  }
  for (iter = input_waveform_.begin(); iter != input_waveform_.end(); ++iter) {
    num_samples_stored += (*iter)->Dim();
    all_pieces.push_back(*iter);
  }
  int64 samples_shift_per_frame =
      static_cast<int64>(sampling_rate_ * feature_pipeline_.FrameShiftInSeconds());
  // Samples accounted for by decoded frames, minus those already dropped:
  int64 num_samp_discard = samples_shift_per_frame * num_frames_decoded_ -
                           num_samples_discarded_,
        num_samp_keep = num_samples_stored - num_samp_discard;
  KALDI_ASSERT(num_samp_discard <= num_samples_stored && num_samp_keep >= 0);
  waveform->Resize(num_samp_keep, kUndefined);
  int32 offset = 0;  // offset into the output waveform.
  for (size_t i = 0; i < all_pieces.size(); i++) {
    Vector<BaseFloat> *this_piece = all_pieces[i];
    int32 this_dim = this_piece->Dim();
    if (num_samp_discard >= this_dim) {
      num_samp_discard -= this_dim;  // this whole piece is discarded.
    } else {
      int32 this_dim_keep = this_dim - num_samp_discard;
      waveform->Range(offset, this_dim_keep).CopyFromVec(
          this_piece->Range(num_samp_discard, this_dim_keep));
      offset += this_dim_keep;
      num_samp_discard = 0;
    }
  }
  KALDI_ASSERT(offset == num_samp_keep && num_samp_discard == 0);
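The assertion above closes the sample bookkeeping in GetRemainingWaveform(). A worked example with made-up numbers:

// Hypothetical numbers: 16 kHz audio, 10 ms frame shift, 200 frames decoded,
// and 30400 samples already dropped from the stored pieces.
int64 samples_shift_per_frame = static_cast<int64>(16000.0 * 0.010);  // = 160
int64 num_samples_used = samples_shift_per_frame * 200;               // = 32000
int64 num_samp_discard = num_samples_used - 30400;                    // = 1600 samples to skip
// Everything after those 1600 samples, across the stored pieces, is copied out.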
    bool end_of_utterance,
  clat->DeleteStates();
  if (final_relative_cost != NULL)
    *final_relative_cost = decoder_.FinalRelativeCost();
  // when no frames have been decoded yet, a trivial one-state lattice is returned:
  clat->SetFinal(clat->AddState(),
                 CompactLatticeWeight::One());
  KALDI_ERR << "--determinize-lattice=false option is not supported at the moment";
    bool end_of_utterance,
  // when no frames have been decoded yet, a trivial one-state path is returned:
  best_path->DeleteStates();
  best_path->SetFinal(best_path->AddState(),
                      LatticeWeight::One());
  if (final_relative_cost != NULL)
    *final_relative_cost = std::numeric_limits<BaseFloat>::infinity();
  // otherwise the decoder's own final-cost information is reported:
  if (final_relative_cost != NULL)
    *final_relative_cost = decoder_.FinalRelativeCost();
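GetBestPath() mirrors GetLattice() but returns a state-level Lattice, and final_relative_cost is infinity when nothing has been decoded. A hedged sketch of how a caller might interpret it; the 5.0 threshold is illustrative only.

#include <limits>

Lattice best_path;
BaseFloat final_relative_cost;
decoder->GetBestPath(true, &best_path, &final_relative_cost);
if (final_relative_cost == std::numeric_limits<BaseFloat>::infinity()) {
  // Nothing decoded yet: best_path is a trivial one-state FST.
} else if (final_relative_cost > 5.0) {
  // Best path ended far from a final state; treat the hypothesis with caution.
}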
      KALDI_ERR << "Returned abnormally and abort was not called";
  } catch(const std::exception &e) {
    KALDI_WARN << "Caught exception: " << e.what();
      KALDI_ERR << "Returned abnormally and abort was not called";
  } catch(const std::exception &e) {
    KALDI_WARN << "Caught exception: " << e.what();
  KALDI_ERR << "Error encountered during decoding. See above.";
  if (cu_loglikes->NumRows() != 0) {
    int32 num_frames_consumed) {
  num_frames_usable = num_frames_ready - num_frames_consumed;
  // the same quantity is recomputed further down, after more input has arrived:
  num_frames_usable = num_frames_ready - num_frames_consumed;
  int32 samples_shift_per_frame =
      sampling_rate_ * feature_pipeline_.FrameShiftInSeconds();
  bool pad_input = true;
  log_inv_prior.ApplyLog();
  log_inv_prior.Scale(-1.0);
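ApplyLog() followed by Scale(-1.0) turns the prior vector into log(1/prior). A hedged sketch of how that vector converts the nnet's posteriors into scaled pseudo-log-likelihoods; the flooring constant and acoustic_scale value are assumptions, not taken from this file.

// For each pdf s:  log p(x|s) + const  =  log p(s|x) - log p(s)
//                                       =  log p(s|x) + log_inv_prior(s).
CuVector<BaseFloat> log_inv_prior(am_nnet.Priors());
log_inv_prior.ApplyFloor(1.0e-20);          // keep log() finite (floor value assumed)
log_inv_prior.ApplyLog();
log_inv_prior.Scale(-1.0);                  // now holds -log p(s)

BaseFloat acoustic_scale = 0.1;             // assumed value
CuMatrix<BaseFloat> loglikes;               // rows = frames, cols = pdfs
// (assume the nnet has already written posterior probabilities into 'loglikes')
loglikes.ApplyFloor(1.0e-20);
loglikes.ApplyLog();                        // log p(s|x)
loglikes.AddVecToRows(1.0, log_inv_prior);  // subtract the log-priors row-wise
loglikes.Scale(acoustic_scale);             // acoustic scaling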
  int32 num_frames_consumed = 0, num_frames_output = 0;
  bool last_time = false;
  std::vector<std::pair<int32, BaseFloat> > delta_weights;
  num_frames_usable = num_frames_ready - num_frames_consumed;
  int32 num_frames_evaluate = std::min<int32>(num_frames_usable,
                                              config_.nnet_batch_size);  // field name assumed
  if (num_frames_evaluate > 0) {
    for (int32 i = 0; i < num_frames_evaluate; i++) {
      int32 t = num_frames_consumed + i;
  if (feats.NumRows() == 0) {
    computer.Flush(&cu_loglikes);
    cu_feats.Swap(&feats);
    computer.Compute(cu_feats, &cu_loglikes);
    num_frames_consumed += cu_feats.NumRows();
    loglikes.Swap(&cu_loglikes);
  if (num_loglike_frames != 0) {
    num_frames_output += num_loglike_frames;
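The delta_weights vector declared above carries per-frame weight updates from the silence-weighting object to the iVector extractor. A hedged sketch of that hand-off, using the two-argument GetDeltaWeights call; exact signatures and the locking done by the real code are simplified here.

// Sketch (assumed): after the decoder has advanced, down-weight frames that the
// current traceback labels as silence, and feed the changes to the iVector
// estimator inside the feature pipeline.
if (silence_weighting_.Active() && feature_pipeline_.IvectorFeature() != NULL) {
  silence_weighting_.ComputeCurrentTraceback(decoder_);
  silence_weighting_.GetDeltaWeights(feature_pipeline_.NumFramesReady(),
                                     &delta_weights);
  feature_pipeline_.IvectorFeature()->UpdateFrameWeights(delta_weights);
}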
  int32 num_frames_decoded = 0;
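This counter opens RunDecoderSearchInternal(), the consumer-side loop. A hedged sketch of its overall shape; the real loop also takes decoder_mutex_ and limits how many frames it decodes per iteration via the config.

// Sketch (assumed): advance the decoder whenever the decodable object has new
// frames; otherwise report "no progress" so the synchronizer blocks until the
// nnet-evaluation thread produces more log-likelihoods.
decoder_.InitDecoding();
while (true) {
  if (!decodable_synchronizer_.Lock(ThreadSynchronizer::kConsumer))
    return false;                                       // decoding was aborted
  if (decodable_.NumFramesReady() <= num_frames_decoded) {
    if (!decodable_synchronizer_.UnlockFailure(ThreadSynchronizer::kConsumer))
      return false;                                     // next Lock() will wait
  } else {
    decoder_.AdvanceDecoding(&decodable_);
    num_frames_decoded = decoder_.NumFramesDecoded();
    bool done = decodable_.IsLastFrame(num_frames_decoded - 1);
    if (!decodable_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer))
      return false;
    if (done) return true;                              // all input has been decoded
  }
}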