NnetBatchInference Class Reference

This class implements a simplified interface to class NnetBatchComputer; it is suitable for programs like 'nnet3-compute' that need fast GPU-based inference on a sequence of utterances and want to retrieve the outputs from the object in the same order. More...

#include <nnet-batch-compute.h>


Classes

struct  UtteranceInfo
 

Public Member Functions

 NnetBatchInference (const NnetBatchComputerOptions &opts, const Nnet &nnet, const VectorBase< BaseFloat > &priors)
 
void AcceptInput (const std::string &utterance_id, const Matrix< BaseFloat > &input, const Vector< BaseFloat > *ivector, const Matrix< BaseFloat > *online_ivectors, int32 online_ivector_period)
 The user should call this one by one for the utterances that this class needs to compute (interspersed with calls to GetOutput()). More...
 
void Finished ()
 The user should call this after the last input has been provided via AcceptInput(). More...
 
bool GetOutput (std::string *utterance_id, Matrix< BaseFloat > *output)
 The user should call this to obtain output. More...
 
 ~NnetBatchInference ()
 

Private Member Functions

 KALDI_DISALLOW_COPY_AND_ASSIGN (NnetBatchInference)
 
void Compute ()
 

Static Private Member Functions

static void ComputeFunc (NnetBatchInference *object)
 

Private Attributes

NnetBatchComputer computer_
 
bool is_finished_
 
Semaphore tasks_ready_semaphore_
 
std::list< UtteranceInfo * > utts_
 
int32 utterance_counter_
 
std::thread compute_thread_
 

Detailed Description

This class implements a simplified interface to class NnetBatchComputer; it is suitable for programs like 'nnet3-compute' that need fast GPU-based inference on a sequence of utterances and want to retrieve the outputs from the object in the same order.

Definition at line 502 of file nnet-batch-compute.h.
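A minimal usage sketch of the intended calling pattern. This is an assumption-laden illustration, not code from the library: features are presumed already loaded into memory, no i-vectors are used, and the helper name RunAll is hypothetical.

#include <string>
#include <utility>
#include <vector>
#include "base/kaldi-common.h"
#include "nnet3/nnet-batch-compute.h"

using namespace kaldi;
using namespace kaldi::nnet3;

// Runs batched inference over pre-loaded features; returns (utterance-id,
// nnet-output) pairs in the same order the utterances were provided.
std::vector<std::pair<std::string, Matrix<BaseFloat> > > RunAll(
    const NnetBatchComputerOptions &opts, const Nnet &nnet,
    const Vector<BaseFloat> &priors,
    const std::vector<std::pair<std::string, Matrix<BaseFloat> > > &utts) {
  NnetBatchInference inference(opts, nnet, priors);
  std::vector<std::pair<std::string, Matrix<BaseFloat> > > results;
  std::string utt_id;
  Matrix<BaseFloat> output;
  for (size_t i = 0; i < utts.size(); i++) {
    // No i-vectors in this sketch: pass NULL for both forms.
    inference.AcceptInput(utts[i].first, utts[i].second, NULL, NULL, 0);
    // Opportunistically collect whatever is already computed; before
    // Finished() has been called, this does not block.
    while (inference.GetOutput(&utt_id, &output))
      results.push_back(std::make_pair(utt_id, output));
  }
  inference.Finished();
  // GetOutput() now waits for the remaining utterances, so this loop
  // drains everything that is left, still in input order.
  while (inference.GetOutput(&utt_id, &output))
    results.push_back(std::make_pair(utt_id, output));
  return results;
}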

Constructor & Destructor Documentation

◆ NnetBatchInference()

NnetBatchInference ( const NnetBatchComputerOptions &  opts,
const Nnet &  nnet,
const VectorBase< BaseFloat > &  priors 
)

Definition at line 1083 of file nnet-batch-compute.cc.

References NnetBatchInference::compute_thread_, and NnetBatchInference::ComputeFunc().

NnetBatchInference::NnetBatchInference(
    const NnetBatchComputerOptions &opts,
    const Nnet &nnet,
    const VectorBase<BaseFloat> &priors):
    computer_(opts, nnet, priors),
    is_finished_(false),
    utterance_counter_(0) {
  // 'compute_thread_' will run the Compute() function in the background.
  compute_thread_ = std::thread(ComputeFunc, this);
}

◆ ~NnetBatchInference()

Definition at line 1157 of file nnet-batch-compute.cc.

References NnetBatchInference::compute_thread_, NnetBatchInference::is_finished_, KALDI_ERR, and NnetBatchInference::utts_.

NnetBatchInference::~NnetBatchInference() {
  if (!is_finished_)
    KALDI_ERR << "Object destroyed before Finished() was called.";
  if (!utts_.empty())
    KALDI_ERR << "You should get all output before destroying this object.";
  compute_thread_.join();
}

Member Function Documentation

◆ AcceptInput()

void AcceptInput ( const std::string &  utterance_id,
const Matrix< BaseFloat > &  input,
const Vector< BaseFloat > *  ivector,
const Matrix< BaseFloat > *  online_ivectors,
int32  online_ivector_period 
)

The user should call this one by one for the utterances that this class needs to compute (interspersed with calls to GetOutput()).

This call will block if too much ready-to-be-computed data is already queued, waiting until the computation thread has made some progress.

Parameters
[in]  utterance_id  The string representing the utterance-id; it will be provided back to the user when GetOutput() is called.
[in]  input  The input features (e.g. MFCCs).
[in]  ivector  If non-NULL, this is expected to be the i-vector for this utterance (and 'online_ivectors' should be NULL).
[in]  online_ivectors  If non-NULL, contains i-vectors estimated online, one per row (and 'ivector' should be NULL).
[in]  online_ivector_period  Only relevant if 'online_ivectors' is non-NULL; says how many frames of 'input' are covered by each row of 'online_ivectors'.
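The two i-vector arguments are mutually exclusive. A short sketch of both calling modes; the feature and i-vector variables here are placeholders, not real data:

void AcceptBothModes(kaldi::nnet3::NnetBatchInference *inference,
                     const kaldi::Matrix<kaldi::BaseFloat> &feats1,
                     const kaldi::Vector<kaldi::BaseFloat> &utt_ivector,
                     const kaldi::Matrix<kaldi::BaseFloat> &feats2,
                     const kaldi::Matrix<kaldi::BaseFloat> &online_ivecs) {
  // Whole-utterance i-vector: pass 'ivector', leave 'online_ivectors' NULL.
  inference->AcceptInput("utt1", feats1, &utt_ivector, NULL, 0);

  // Online i-vectors: 'ivector' is NULL; each row of 'online_ivecs' covers
  // 'online_ivector_period' frames of 'feats2', so with a period of 10,
  // row k applies to frames [10k, 10k+9].
  inference->AcceptInput("utt2", feats2, NULL, &online_ivecs, 10);
}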

Definition at line 1095 of file nnet-batch-compute.cc.

References NnetBatchComputer::AcceptTask(), NnetBatchInference::computer_, rnnlm::i, NnetBatchInference::UtteranceInfo::num_tasks_finished, Semaphore::Signal(), NnetBatchComputer::SplitUtteranceIntoTasks(), NnetBatchInference::UtteranceInfo::tasks, NnetBatchInference::tasks_ready_semaphore_, NnetBatchInference::utterance_counter_, NnetBatchInference::UtteranceInfo::utterance_id, and NnetBatchInference::utts_.

Referenced by main().

void NnetBatchInference::AcceptInput(
    const std::string &utterance_id,
    const Matrix<BaseFloat> &input,
    const Vector<BaseFloat> *ivector,
    const Matrix<BaseFloat> *online_ivectors,
    int32 online_ivector_period) {

  UtteranceInfo *info = new UtteranceInfo();
  info->utterance_id = utterance_id;
  info->num_tasks_finished = 0;
  bool output_to_cpu = true;  // This wrapper is for when you need the nnet
                              // output on CPU, e.g. because you want it
                              // written to disk.  If this needs to be
                              // configurable in the future, we can make
                              // changes then.
  computer_.SplitUtteranceIntoTasks(output_to_cpu, input, ivector,
                                    online_ivectors, online_ivector_period,
                                    &(info->tasks));

  // Setting this to a nonzero value will cause the AcceptTask() call below
  // to hang until the computation thread has made some progress, if too
  // much data is already queued.
  int32 max_full_minibatches = 2;

  // Earlier utterances have higher priority, which is important to make
  // sure that their corresponding tasks are completed and they can be
  // output to disk.
  double priority = -1.0 * (utterance_counter_++);
  for (size_t i = 0; i < info->tasks.size(); i++) {
    info->tasks[i].priority = priority;
    computer_.AcceptTask(&(info->tasks[i]), max_full_minibatches);
  }
  utts_.push_back(info);
  tasks_ready_semaphore_.Signal();
}

◆ Compute()

void Compute ( )
private

Definition at line 1171 of file nnet-batch-compute.cc.

References NnetBatchComputer::Compute(), NnetBatchInference::computer_, NnetBatchInference::is_finished_, NnetBatchInference::tasks_ready_semaphore_, and Semaphore::Wait().

void NnetBatchInference::Compute() {
  bool allow_partial_minibatch = false;
  while (true) {
    // Keep calling Compute() as long as it makes progress.
    while (computer_.Compute(allow_partial_minibatch));
    // ... then wait on tasks_ready_semaphore_.
    tasks_ready_semaphore_.Wait();
    if (is_finished_) {
      allow_partial_minibatch = true;
      while (computer_.Compute(allow_partial_minibatch));
      return;
    }
  }
}

◆ ComputeFunc()

static void ComputeFunc ( NnetBatchInference *  object)
inline static private

Definition at line 565 of file nnet-batch-compute.h.

Referenced by NnetBatchInference::NnetBatchInference().

static void ComputeFunc(NnetBatchInference *object) { object->Compute(); }

◆ Finished()

void Finished ( )

The user should call this after the last input has been provided via AcceptInput().

This will force the last utterances to be flushed out (to be retrieved by GetOutput()), rather than waiting until the relevant minibatches are full.

Definition at line 1165 of file nnet-batch-compute.cc.

References NnetBatchInference::is_finished_, Semaphore::Signal(), and NnetBatchInference::tasks_ready_semaphore_.

Referenced by main().

void NnetBatchInference::Finished() {
  is_finished_ = true;
  tasks_ready_semaphore_.Signal();
}

◆ GetOutput()

bool GetOutput ( std::string *  utterance_id,
Matrix< BaseFloat > *  output 
)

The user should call this to obtain output.

It's guaranteed to be in the same order as the input was provided, but it may be delayed. 'output' will be the output of the neural net, spliced together over the chunks (and with acoustic scaling applied if it was specified in the options); whether priors are subtracted depends on whether you supplied a non-empty vector of priors to the constructor.

This call does not block (i.e. does not wait on any semaphores) unless you have previously called Finished(). It returns true if it actually got any output; if none was ready it will return false.
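For example, a hedged sketch of the final drain after Finished(), writing each utterance's output to a Kaldi archive; the helper name DrainToArchive and the wspecifier path are hypothetical:

#include <string>
#include "nnet3/nnet-batch-compute.h"
#include "util/table-types.h"

// Retrieves all remaining output in input order; because Finished() has
// been called, GetOutput() now blocks until each utterance in turn is
// complete, and returns false only when nothing is left.
void DrainToArchive(kaldi::nnet3::NnetBatchInference *inference) {
  kaldi::BaseFloatMatrixWriter writer("ark:output.ark");  // hypothetical path
  std::string utt_id;
  kaldi::Matrix<kaldi::BaseFloat> output;
  while (inference->GetOutput(&utt_id, &output))
    writer.Write(utt_id, output);
}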

Definition at line 1130 of file nnet-batch-compute.cc.

References NnetBatchInference::is_finished_, kaldi::nnet3::MergeTaskOutput(), NnetBatchInference::UtteranceInfo::num_tasks_finished, NnetBatchInference::UtteranceInfo::tasks, Semaphore::TryWait(), NnetBatchInference::UtteranceInfo::utterance_id, NnetBatchInference::utts_, and Semaphore::Wait().

Referenced by main().

bool NnetBatchInference::GetOutput(std::string *utterance_id,
                                   Matrix<BaseFloat> *output) {
  if (utts_.empty())
    return false;

  UtteranceInfo *info = *utts_.begin();
  std::vector<NnetInferenceTask> &tasks = info->tasks;
  int32 num_tasks = tasks.size();
  for (; info->num_tasks_finished < num_tasks; ++info->num_tasks_finished) {
    Semaphore &semaphore = tasks[info->num_tasks_finished].semaphore;
    if (is_finished_) {
      semaphore.Wait();
    } else {
      if (!semaphore.TryWait()) {
        // If not all of the tasks of this utterance are ready yet,
        // just return false.
        return false;
      }
    }
  }
  MergeTaskOutput(tasks, output);
  *utterance_id = info->utterance_id;
  delete info;
  utts_.pop_front();
  return true;
}

◆ KALDI_DISALLOW_COPY_AND_ASSIGN()

KALDI_DISALLOW_COPY_AND_ASSIGN ( NnetBatchInference  )
private

Member Data Documentation

◆ compute_thread_

std::thread compute_thread_
private

◆ computer_

NnetBatchComputer computer_
private

◆ is_finished_

bool is_finished_
private

◆ tasks_ready_semaphore_

Semaphore tasks_ready_semaphore_
private

◆ utterance_counter_

int32 utterance_counter_
private

Definition at line 599 of file nnet-batch-compute.h.

Referenced by NnetBatchInference::AcceptInput().

◆ utts_

std::list< UtteranceInfo * > utts_
private


The documentation for this class was generated from the following files:
nnet-batch-compute.h
nnet-batch-compute.cc