This class implements a simplified interface to class NnetBatchComputer, which is suitable for programs like 'nnet3-compute' where you want to support fast GPU-based inference on a sequence of utterances, and get them back from the object in the same order.
#include <nnet-batch-compute.h>
Classes | |
struct | UtteranceInfo |
Public Member Functions | |
NnetBatchInference (const NnetBatchComputerOptions &opts, const Nnet &nnet, const VectorBase< BaseFloat > &priors) | |
void | AcceptInput (const std::string &utterance_id, const Matrix< BaseFloat > &input, const Vector< BaseFloat > *ivector, const Matrix< BaseFloat > *online_ivectors, int32 online_ivector_period) |
The user should call this one by one for the utterances that this class needs to compute (interspersed with calls to GetOutput()). | |
void | Finished () |
The user should call this after the last input has been provided via AcceptInput(). | |
bool | GetOutput (std::string *utterance_id, Matrix< BaseFloat > *output) |
The user should call this to obtain output. | |
~NnetBatchInference () | |
Private Member Functions | |
KALDI_DISALLOW_COPY_AND_ASSIGN (NnetBatchInference) | |
void | Compute () |
Static Private Member Functions | |
static void | ComputeFunc (NnetBatchInference *object) |
Private Attributes | |
NnetBatchComputer | computer_ |
bool | is_finished_ |
Semaphore | tasks_ready_semaphore_ |
std::list< UtteranceInfo * > | utts_ |
int32 | utterance_counter_ |
std::thread | compute_thread_ |
Definition at line 502 of file nnet-batch-compute.h.
NnetBatchInference | ( | const NnetBatchComputerOptions & | opts, |
const Nnet & | nnet, | ||
const VectorBase< BaseFloat > & | priors | ||
) |
Definition at line 1083 of file nnet-batch-compute.cc.
References NnetBatchInference::compute_thread_, and NnetBatchInference::ComputeFunc().
~NnetBatchInference | ( | ) |
Definition at line 1157 of file nnet-batch-compute.cc.
References NnetBatchInference::compute_thread_, NnetBatchInference::is_finished_, KALDI_ERR, and NnetBatchInference::utts_.
void AcceptInput | ( | const std::string & | utterance_id, |
const Matrix< BaseFloat > & | input, | ||
const Vector< BaseFloat > * | ivector, | ||
const Matrix< BaseFloat > * | online_ivectors, | ||
int32 | online_ivector_period | ||
) |
The user should call this one by one for the utterances that this class needs to compute (interspersed with calls to GetOutput()).
This call will block when enough ready-to-be-computed data is present.
[in] | utterance_id | The string representing the utterance-id; it will be provided back to the user when GetOutput() is called. |
[in] | input | The input features (e.g. MFCCs) |
[in] | ivector | If non-NULL, this is expected to be the i-vector for this utterance (and 'online_ivectors' should be NULL). |
[in] | online_ivector_period | Only relevant if 'online_ivectors' is non-NULL; this says how many frames of 'input' are covered by each row of 'online_ivectors'. |
Definition at line 1095 of file nnet-batch-compute.cc.
References NnetBatchComputer::AcceptTask(), NnetBatchInference::computer_, NnetBatchInference::UtteranceInfo::num_tasks_finished, Semaphore::Signal(), NnetBatchComputer::SplitUtteranceIntoTasks(), NnetBatchInference::UtteranceInfo::tasks, NnetBatchInference::tasks_ready_semaphore_, NnetBatchInference::utterance_counter_, NnetBatchInference::UtteranceInfo::utterance_id, and NnetBatchInference::utts_.
Referenced by main().
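The meaning of 'online_ivector_period' can be illustrated with a small calculation. This is only a sketch: it assumes that row r of 'online_ivectors' covers frames r * period up to (r + 1) * period - 1 of 'input', and that frames past the last row reuse the last row. The helper name OnlineIvectorRow is hypothetical and is not part of Kaldi.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical helper (not a Kaldi function): returns the row of
// 'online_ivectors' assumed to cover the given frame of 'input'.
// Assumption: each row covers 'online_ivector_period' consecutive frames,
// and frames beyond the last row fall back to the last row.
int OnlineIvectorRow(int frame, int online_ivector_period,
                     int num_ivector_rows) {
  assert(frame >= 0 && online_ivector_period > 0 && num_ivector_rows > 0);
  return std::min(frame / online_ivector_period, num_ivector_rows - 1);
}
```

Under these assumptions, with a period of 10 and 5 rows, frames 0 to 9 map to row 0, frame 10 maps to row 1, and any frame from 40 onwards maps to row 4.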
void Compute | ( | ) |
private |
Definition at line 1171 of file nnet-batch-compute.cc.
References NnetBatchComputer::Compute(), NnetBatchInference::computer_, NnetBatchInference::is_finished_, NnetBatchInference::tasks_ready_semaphore_, and Semaphore::Wait().
static void ComputeFunc | ( | NnetBatchInference * | object | ) |
inline static private |
Definition at line 565 of file nnet-batch-compute.h.
Referenced by NnetBatchInference::NnetBatchInference().
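The cross-references above suggest a common pattern: the constructor starts compute_thread_ on the static ComputeFunc(), which forwards to the private Compute(), which in turn waits on tasks_ready_semaphore_ until Finished() sets is_finished_. The following is a minimal sketch of that pattern using only the C++ standard library. ToySemaphore, ToyWorker, AddTask(), JoinAndCount() and ProcessTasks() are hypothetical stand-ins, and the loop body is an assumption rather than the actual Kaldi code.

```cpp
#include <atomic>
#include <cassert>
#include <condition_variable>
#include <mutex>
#include <thread>

// Hypothetical counting semaphore standing in for Kaldi's Semaphore.
class ToySemaphore {
 public:
  void Signal() {
    std::lock_guard<std::mutex> lock(mutex_);
    ++count_;
    cv_.notify_one();
  }
  void Wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return count_ > 0; });
    --count_;
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  int count_ = 0;
};

// Hypothetical stand-in for the threading skeleton of NnetBatchInference.
class ToyWorker {
 public:
  ToyWorker() : is_finished_(false), pending_(0), num_processed_(0) {
    // Static trampoline: std::thread calls ComputeFunc(this).
    compute_thread_ = std::thread(ComputeFunc, this);
  }
  ~ToyWorker() {
    // Toy behaviour only: flush instead of reporting an error.
    if (!is_finished_) Finished();
    if (compute_thread_.joinable()) compute_thread_.join();
  }
  void AddTask() { ++pending_; tasks_ready_.Signal(); }
  void Finished() { is_finished_ = true; tasks_ready_.Signal(); }
  // Waits for the background thread and returns how many tasks it handled.
  int JoinAndCount() {
    assert(is_finished_);
    if (compute_thread_.joinable()) compute_thread_.join();
    return num_processed_;
  }

 private:
  static void ComputeFunc(ToyWorker *object) { object->Compute(); }
  void Drain() { num_processed_ += pending_.exchange(0); }
  void Compute() {
    while (true) {
      Drain();              // do whatever work is currently available
      tasks_ready_.Wait();  // sleep until a task or Finished() arrives
      if (is_finished_) {
        Drain();            // final flush of leftover work
        return;
      }
    }
  }

  std::atomic<bool> is_finished_;
  std::atomic<int> pending_;
  std::atomic<int> num_processed_;
  ToySemaphore tasks_ready_;
  std::thread compute_thread_;  // declared last, started in the constructor
};

// Small driver: submits num_tasks tasks and returns the processed count.
int ProcessTasks(int num_tasks) {
  ToyWorker worker;
  for (int i = 0; i < num_tasks; ++i) worker.AddTask();
  worker.Finished();
  return worker.JoinAndCount();
}
```

In this sketch the flag and counters are atomics so that the toy is free of data races on its own; how the real class synchronizes is_finished_ is not shown on this page.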
void Finished | ( | ) |
The user should call this after the last input has been provided via AcceptInput().
This will force the last utterances to be flushed out (to be retrieved by GetOutput()), rather than waiting until the relevant minibatches are full.
Definition at line 1165 of file nnet-batch-compute.cc.
References NnetBatchInference::is_finished_, Semaphore::Signal(), and NnetBatchInference::tasks_ready_semaphore_.
Referenced by main().
bool GetOutput | ( | std::string * | utterance_id, |
Matrix< BaseFloat > * | output | ||
) |
The user should call this to obtain output.
The output is guaranteed to come back in the same order in which the input was provided, but it may be delayed. 'output' will be the output of the neural net, spliced together over the chunks, with acoustic scaling applied if it was specified in the options. Whether priors are subtracted depends on whether you supplied a non-empty vector of priors to the constructor.
This call does not block (i.e. does not wait on any semaphores) unless you have previously called Finished(). It returns true if it actually got any output; if none was ready it will return false.
Definition at line 1130 of file nnet-batch-compute.cc.
References NnetBatchInference::is_finished_, kaldi::nnet3::MergeTaskOutput(), NnetBatchInference::UtteranceInfo::num_tasks_finished, NnetBatchInference::UtteranceInfo::tasks, Semaphore::TryWait(), NnetBatchInference::UtteranceInfo::utterance_id, NnetBatchInference::utts_, and Semaphore::Wait().
Referenced by main().
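The calling pattern described for AcceptInput(), Finished() and GetOutput() can be sketched as follows. Since Kaldi is not assumed to be available here, ToyInference is a hypothetical single-threaded stand-in that mimics only the call order and the return-value semantics: outputs become ready once a full "minibatch" of utterances has been accepted, or after Finished(). The "inference" step (doubling each value) and the RunAll() driver are likewise illustrative assumptions, not the Kaldi implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for NnetBatchInference: no GPU, no threads.
class ToyInference {
 public:
  explicit ToyInference(std::size_t minibatch_size)
      : minibatch_size_(minibatch_size) {}

  void AcceptInput(const std::string &utt_id,
                   const std::vector<float> &input) {
    pending_.emplace_back(utt_id, input);
    if (pending_.size() >= minibatch_size_) Flush();
  }

  // Call after the last AcceptInput(): flushes a partially filled minibatch.
  void Finished() { Flush(); }

  // Returns false if nothing is ready; outputs come back in input order.
  bool GetOutput(std::string *utt_id, std::vector<float> *output) {
    assert(utt_id != nullptr && output != nullptr);
    if (ready_.empty()) return false;
    *utt_id = ready_.front().first;
    *output = ready_.front().second;
    ready_.pop_front();
    return true;
  }

 private:
  typedef std::pair<std::string, std::vector<float> > Item;

  void Flush() {
    for (Item &item : pending_) {
      for (float &x : item.second) x *= 2.0f;  // placeholder "inference"
      ready_.push_back(item);
    }
    pending_.clear();
  }

  std::size_t minibatch_size_;
  std::deque<Item> pending_;
  std::deque<Item> ready_;
};

// Driver mirroring the documented pattern: AcceptInput() interspersed with
// GetOutput(), then Finished(), then a final drain of GetOutput().
// Returns the utterance ids in the order they were output.
std::vector<std::string> RunAll(const std::vector<std::string> &utt_ids) {
  ToyInference inference(2);
  std::vector<std::string> order;
  std::string id;
  std::vector<float> out;
  for (const std::string &utt : utt_ids) {
    inference.AcceptInput(utt, std::vector<float>(3, 1.0f));
    while (inference.GetOutput(&id, &out)) order.push_back(id);
  }
  inference.Finished();
  while (inference.GetOutput(&id, &out)) order.push_back(id);
  return order;
}
```

Draining GetOutput() inside the input loop keeps the number of buffered results small, while the final loop after Finished() collects the utterances that were still waiting for a full minibatch.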
KALDI_DISALLOW_COPY_AND_ASSIGN | ( | NnetBatchInference | ) |
private |
std::thread compute_thread_ |
private |
Definition at line 602 of file nnet-batch-compute.h.
Referenced by NnetBatchInference::NnetBatchInference(), and NnetBatchInference::~NnetBatchInference().
NnetBatchComputer computer_ |
private |
Definition at line 572 of file nnet-batch-compute.h.
Referenced by NnetBatchInference::AcceptInput(), and NnetBatchInference::Compute().
bool is_finished_ |
private |
Definition at line 576 of file nnet-batch-compute.h.
Referenced by NnetBatchInference::Compute(), NnetBatchInference::Finished(), NnetBatchInference::GetOutput(), and NnetBatchInference::~NnetBatchInference().
Semaphore tasks_ready_semaphore_ |
private |
Definition at line 581 of file nnet-batch-compute.h.
Referenced by NnetBatchInference::AcceptInput(), NnetBatchInference::Compute(), and NnetBatchInference::Finished().
int32 utterance_counter_ |
private |
Definition at line 599 of file nnet-batch-compute.h.
Referenced by NnetBatchInference::AcceptInput().
std::list< UtteranceInfo * > utts_ |
private |
Definition at line 597 of file nnet-batch-compute.h.
Referenced by NnetBatchInference::AcceptInput(), NnetBatchInference::GetOutput(), and NnetBatchInference::~NnetBatchInference().