// nnet3/nnet-batch-compute.h

// Copyright 2012-2018  Johns Hopkins University (author: Daniel Povey)
//           2018       Hang Lyu

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_NNET3_NNET_BATCH_COMPUTE_H_
#define KALDI_NNET3_NNET_BATCH_COMPUTE_H_

#include <vector>
#include <string>
#include <list>
#include <utility>
#include <condition_variable>
#include "base/kaldi-common.h"
#include "gmm/am-diag-gmm.h"
#include "hmm/transition-model.h"
#include "itf/decodable-itf.h"
#include "nnet3/nnet-optimize.h"
#include "nnet3/nnet-compute.h"
#include "nnet3/am-nnet-simple.h"
#include "nnet3/nnet-am-decodable-simple.h"
#include "decoder/lattice-faster-decoder.h"
#include "util/stl-utils.h"


namespace kaldi {
namespace nnet3 {


/**
   class NnetInferenceTask represents a chunk of an utterance that is
   requested to be computed.  This will be given to NnetBatchComputer, which
   will aggregate the tasks and complete them.
 */
struct NnetInferenceTask {
  // The copy constructor is required to exist because of std::vector's
  // resize() function, but in practice it should never be used.
  NnetInferenceTask(const NnetInferenceTask &other) {
    KALDI_ERR << "NnetInferenceTask was not designed to be copied.";
  }
  NnetInferenceTask() { }


  // The input frames, which are treated as being numbered t=0, t=1, etc.  (If
  // the lowest t value was originally nonzero in the 'natural' numbering, this
  // just means we conceptually shift the 't' values; the only real constraint
  // is that the 't' values are contiguous.)
  CuMatrix<BaseFloat> input;

  // The index of the first input frame (in the shifted numbering where the
  // first output frame is numbered zero).  This will typically be less than
  // one, because most network topologies require left context.  If this was an
  // 'interior' chunk of a recurrent topology like LSTMs, first_input_t may be
  // substantially less than zero, due to 'extra_left_context'.
  int32 first_input_t;

  // The stride of output 't' values: e.g., this will be 1 for
  // normal-frame-rate models, and 3 for low-frame-rate models such as chain
  // models.
  int32 output_t_stride;

  // The number of output 't' values (they will start from zero and be
  // separated by output_t_stride).  This will be the num-rows of 'output'.
  int32 num_output_frames;

  // 'num_initial_unused_output_frames', which will normally be zero, is the
  // number of rows of the output matrix ('output' or 'output_cpu') which won't
  // actually be needed by the user, usually because they overlap with a
  // previous chunk.  This can happen because the number of outputs isn't a
  // multiple of the number of chunks.
  int32 num_initial_unused_output_frames;

  // 0 < num_used_output_frames <=
  //     num_output_frames - num_initial_unused_output_frames
  // is the number of output frames which are actually going to be used by the
  // user.  (Due to edge effects, not all are necessarily used.)
  int32 num_used_output_frames;

  // first_used_output_frame_index is provided for the convenience of the user
  // so that they can know how this chunk relates to the utterance which it is
  // a part of.  It represents an output frame index in the original
  // utterance-- after subsampling; so not a 't' value but a 't' value divided
  // by frame-subsampling-factor.  Specifically, it tells you the row index in
  // the full utterance's output which corresponds to the first 'used' frame
  // index at the output of this chunk, specifically: the row numbered
  // 'num_initial_unused_output_frames' of the 'output' or 'output_cpu' data
  // member.
  int32 first_used_output_frame_index;

  // True if this chunk is an 'edge' (the beginning or end of an utterance) AND
  // is structurally different somehow from a non-edge chunk, e.g. requires
  // less context.  This is present only so that NnetBatchComputer will know
  // the appropriate minibatch size to use.
  bool is_edge;

  // True if this task represents an irregular-sized chunk.  These can happen
  // only for utterances that are shorter than the requested minibatch size,
  // and they should be quite rare.  We use a minibatch size of 1 in this case.
  bool is_irregular;

  // The i-vector for this chunk, if this network accepts i-vector inputs.
  CuVector<BaseFloat> ivector;

  // A priority (higher is more urgent); it may have either sign.  May be
  // updated after this object is provided to class NnetBatchComputer.
  double priority;

  // This semaphore will be incremented by class NnetBatchComputer when this
  // chunk is done.  After this semaphore is incremented, class
  // NnetBatchComputer will no longer hold any pointers to this class.
  Semaphore semaphore;

  // Will be set to true by the caller if they want the output of the neural
  // net to be copied to the CPU (to 'output_cpu').  If false, the output will
  // stay on the GPU (if one is used), in 'output'.
  bool output_to_cpu;

  // The neural net output, of dimension num_output_frames by the output-dim of
  // the neural net, will be written to 'output_cpu' if 'output_to_cpu' is
  // true.  This is expected to be empty when this task is provided to class
  // NnetBatchComputer, and will be nonempty (if output_to_cpu == true) when
  // the task is completed and the semaphore is signaled.
  Matrix<BaseFloat> output_cpu;

  // The output goes here, instead of 'output_cpu', if 'output_to_cpu' is
  // false.
  CuMatrix<BaseFloat> output;
};
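
// To illustrate how the indexing members above fit together: a minimal,
// hedged sketch (the names 'task' and 'utt_output' are hypothetical) of how
// a completed task's used rows would be copied back into a matrix covering
// the whole utterance:
//
//   int32 src_row = task.num_initial_unused_output_frames,
//       dst_row = task.first_used_output_frame_index;
//   for (int32 i = 0; i < task.num_used_output_frames; i++)
//     utt_output.Row(dst_row + i).CopyFromVec(
//         task.output_cpu.Row(src_row + i));
//
// (MergeTaskOutput(), declared below, does this consolidation for you.)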


struct NnetBatchComputerOptions: public NnetSimpleComputationOptions {
  int32 minibatch_size;
  int32 edge_minibatch_size;
  bool ensure_exact_final_context;
  BaseFloat partial_minibatch_factor;

  NnetBatchComputerOptions(): minibatch_size(128),
                              edge_minibatch_size(32),
                              ensure_exact_final_context(false),
                              partial_minibatch_factor(0.5) {
  }

  void Register(OptionsItf *po) {
    NnetSimpleComputationOptions::Register(po);
    po->Register("minibatch-size", &minibatch_size, "Number of chunks per "
                 "minibatch (see also edge-minibatch-size)");
    po->Register("edge-minibatch-size", &edge_minibatch_size, "Number of "
                 "chunks per minibatch: this applies to chunks at the "
                 "beginnings and ends of utterances, in cases (such as "
                 "recurrent models) when the computation would be different "
                 "from the usual one.");
    po->Register("ensure-exact-final-context", &ensure_exact_final_context,
                 "If true, for utterances shorter than --frames-per-chunk, "
                 "use exact-length, special computations.  If false, "
                 "pad with repeats of the last frame.  Would only affect "
                 "the output for backwards-recurrent models, but would "
                 "negatively impact speed in all cases.");
    po->Register("partial-minibatch-factor", &partial_minibatch_factor,
                 "Factor that controls how small partial minibatches will be "
                 "when they become necessary.  We will potentially do the "
                 "computation for sizes: int(partial_minibatch_factor^n * "
                 "minibatch_size), for n = 0, 1, 2....  Set it to 0.0 if you "
                 "want to use only the specified minibatch sizes.");
  }
};
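
// For orientation, a minimal sketch of how these options would typically be
// hooked up to a command line (the usage string and variable names are
// hypothetical; ParseOptions is from util/parse-options.h):
//
//   const char *usage = "...";
//   ParseOptions po(usage);
//   NnetBatchComputerOptions opts;
//   opts.Register(&po);
//   po.Read(argc, argv);
//   // opts.minibatch_size etc. now reflect --minibatch-size etc.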


/**
   Merges together the 'output_cpu' (if the 'output_to_cpu' members are true)
   or the 'output' members of 'tasks' into a single output matrix (the first
   version outputs to a CPU matrix, the second to a GPU matrix).  The tasks
   should all come from the same utterance, and should already have been
   computed.
 */
void MergeTaskOutput(
    const std::vector<NnetInferenceTask> &tasks,
    Matrix<BaseFloat> *output);
void MergeTaskOutput(
    const std::vector<NnetInferenceTask> &tasks,
    CuMatrix<BaseFloat> *output);

/**
   This class does neural net inference in a way that is optimized for GPU
   use: it combines chunks of multiple utterances into minibatches for more
   efficient computation.
 */
class NnetBatchComputer {
 public:
  /// Constructor.  It stores references to all the arguments, so don't delete
  /// them until this object goes out of scope.  'priors' is either the empty
  /// vector, or a vector of prior probabilities which we take the log of and
  /// subtract from the neural net outputs (e.g. used in acoustic modeling).
  NnetBatchComputer(const NnetBatchComputerOptions &opts,
                    const Nnet &nnet,
                    const VectorBase<BaseFloat> &priors);


  /// Accepts a task, meaning the task will be queued.  (Note: the pointer is
  /// still owned by the caller.)  If 'max_minibatches_full' >= 0, the calling
  /// thread will block until no more than that many full minibatches are
  /// pending; this is a mechanism to stop too much data from piling up.
  void AcceptTask(NnetInferenceTask *task,
                  int32 max_minibatches_full = -1);

  /// Returns the number of full minibatches waiting to be computed.
  int32 NumFullPendingMinibatches() const { return num_full_minibatches_; }


  /// Does some kind of computation, choosing the highest-priority thing to
  /// compute.  Returns true if it did some computation, and false if there
  /// was nothing to compute (or, if allow_partial_minibatch == false, if no
  /// full minibatch was ready).
  bool Compute(bool allow_partial_minibatch);


  /// Splits a single utterance into a list of separate tasks which can then
  /// be given to this class by AcceptTask().  The 'ivector' and
  /// 'online_ivectors' arguments are alternatives; at most one of them should
  /// be non-NULL.  The two overloads are equivalent apart from whether the
  /// inputs live in CPU or GPU memory.
  void SplitUtteranceIntoTasks(
      bool output_to_cpu,
      const Matrix<BaseFloat> &input,
      const Vector<BaseFloat> *ivector,
      const Matrix<BaseFloat> *online_ivectors,
      int32 online_ivector_period,
      std::vector<NnetInferenceTask> *tasks);
  void SplitUtteranceIntoTasks(
      bool output_to_cpu,
      const CuMatrix<BaseFloat> &input,
      const CuVector<BaseFloat> *ivector,
      const CuMatrix<BaseFloat> *online_ivectors,
      int32 online_ivector_period,
      std::vector<NnetInferenceTask> *tasks);

  const NnetBatchComputerOptions &GetOptions() { return opts_; }

  ~NnetBatchComputer();

 private:
  KALDI_DISALLOW_COPY_AND_ASSIGN(NnetBatchComputer);

  // Information about a specific minibatch size for a group of tasks sharing
  // a specific structure (in terms of left and right context, etc.).
  struct MinibatchSizeInfo {
    // the computation for this minibatch size.
    std::shared_ptr<const NnetComputation> computation;
    int32 num_done;  // The number of minibatches computed: for diagnostics.
    int64 tot_num_tasks;  // The total number of tasks in those minibatches,
                          // also for diagnostics... can be used to compute
                          // how 'full', on average, these minibatches were.
    double seconds_taken;  // The total time elapsed in computation for this
                           // minibatch type.
    MinibatchSizeInfo(): computation(NULL), num_done(0),
                         tot_num_tasks(0), seconds_taken(0.0) { }
  };


  // A computation group is a group of tasks that have the same structure
  // (number of input and output frames, left and right context).
  struct ComputationGroupInfo {
    // The tasks to be completed.  This array is added-to by AcceptTask(),
    // and removed-from by GetHighestPriorityComputation(), which is called
    // from Compute().
    std::vector<NnetInferenceTask*> tasks;

    // Map from minibatch-size to information specific to this minibatch-size,
    // including the NnetComputation.  This is set up by
    // GetHighestPriorityComputation(), which is called from Compute().
    std::map<int32, MinibatchSizeInfo> minibatch_info;
  };

  // This struct allows us to arrange the tasks into groups that can be
  // computed in the same minibatch.
  struct ComputationGroupKey {
    ComputationGroupKey(const NnetInferenceTask &task):
        num_input_frames(task.input.NumRows()),
        first_input_t(task.first_input_t),
        num_output_frames(task.num_output_frames) { }

    bool operator == (const ComputationGroupKey &other) const {
      return num_input_frames == other.num_input_frames &&
          first_input_t == other.first_input_t &&
          num_output_frames == other.num_output_frames;
    }
    int32 num_input_frames;
    int32 first_input_t;
    int32 num_output_frames;
  };

  struct ComputationGroupKeyHasher {
    int32 operator () (const ComputationGroupKey &key) const {
      return key.num_input_frames + 18043 * key.first_input_t +
          6413 * key.num_output_frames;
    }
  };


  typedef unordered_map<ComputationGroupKey, ComputationGroupInfo,
                        ComputationGroupKeyHasher> MapType;

  // Gets the priority for a group (higher means higher priority).  (A group
  // is a list of tasks that may be computed in the same minibatch.)  What this
  // function does is a kind of heuristic.
  // If allow_partial_minibatch == false, it will set the priority for
  // any minibatches that are not full to negative infinity.
  inline double GetPriority(bool allow_partial_minibatch,
                            const ComputationGroupInfo &info) const;

  // Returns the minibatch size for this group of tasks, i.e. the size of a
  // full minibatch for this type of task, which is what we'd ideally like to
  // compute.  Note: the is_edge and is_irregular options should be the same
  // for all tasks in the group.
  //  - If 'tasks' is empty, or info.is_edge and info.is_irregular are both
  //    false, then return opts_.minibatch_size.
  //  - If 'tasks' is nonempty and tasks[0].is_irregular is true, then
  //    return 1.
  //  - If 'tasks' is nonempty, tasks[0].is_irregular is false and
  //    tasks[0].is_edge is true, then return opts_.edge_minibatch_size.
  inline int32 GetMinibatchSize(const ComputationGroupInfo &info) const;


  // This function compiles, and returns, a computation for tasks of the
  // structure present in info.tasks[0], and the specified minibatch size.
  std::shared_ptr<const NnetComputation> GetComputation(
      const ComputationGroupInfo &info,
      int32 minibatch_size);


  // Returns the actual minibatch size we'll use for this computation.  In
  // most cases it will be opts_.minibatch_size (or opts_.edge_minibatch_size
  // if appropriate); but if the number of available tasks is much less than
  // the appropriate minibatch size, it may be less.  The minibatch size may
  // be greater than info.tasks.size(); in that case, the remaining 'n' values
  // in the minibatch are not used.  (It may also be less than
  // info.tasks.size(), in which case we only do some of the tasks.)
  int32 GetActualMinibatchSize(const ComputationGroupInfo &info) const;


  // This function gets the highest-priority 'num_tasks' tasks from 'info',
  // removes them from the array info->tasks, and puts them into the array
  // 'tasks' (which is assumed to be initially empty).
  // This function also updates the num_full_minibatches_ variable if
  // necessary, and takes care of notifying any related condition variables.
  void GetHighestPriorityTasks(
      int32 num_tasks,
      ComputationGroupInfo *info,
      std::vector<NnetInferenceTask*> *tasks);

  // This function finds the highest-priority group of tasks, removes the
  // tasks that will be computed from the queue (outputting them to 'tasks'),
  // outputs the minibatch size to use to 'minibatch_size', and returns the
  // MinibatchSizeInfo object for that group and minibatch size.  It returns
  // NULL if there is nothing to compute.
  MinibatchSizeInfo *GetHighestPriorityComputation(
      bool allow_partial_minibatch,
      int32 *minibatch_size,
      std::vector<NnetInferenceTask*> *tasks);

  // Formats the inputs to the computation and transfers them to the GPU (if
  // one is used): copies the 'input' and 'ivector' members of the provided
  // tasks into the minibatch-level matrices 'input' and 'ivector' that the
  // compiled computation consumes.
  void FormatInputs(int32 minibatch_size,
                    const std::vector<NnetInferenceTask*> &tasks,
                    CuMatrix<BaseFloat> *input,
                    CuMatrix<BaseFloat> *ivector);


  // Copies 'output', piece by piece, to the 'output_cpu' or 'output'
  // members of 'tasks', depending on their 'output_to_cpu' value.
  void FormatOutputs(const CuMatrix<BaseFloat> &output,
                     const std::vector<NnetInferenceTask*> &tasks);


  // Changes opts_.frames_per_chunk to be a multiple of
  // opts_.frame_subsampling_factor, if needed.
  void CheckAndFixConfigs();

  // This function creates and returns the computation request which is to be
  // compiled.
  static void GetComputationRequest(const NnetInferenceTask &task,
                                    int32 minibatch_size,
                                    ComputationRequest *request);

  // Prints some logging information about what we computed, with breakdown by
  // minibatch type.
  void PrintMinibatchStats();

  NnetBatchComputerOptions opts_;
  const Nnet &nnet_;
  CachingOptimizingCompiler compiler_;
  CuVector<BaseFloat> log_priors_;

  // Mutex that guards this object.  It is only held for fairly quick
  // operations (not while the actual computation is being done).
  std::mutex mutex_;

  // tasks_ contains all the queued tasks.
  // Each key contains a vector of NnetInferenceTask* pointers, of the same
  // structure (i.e., IsCompatible() returns true).
  MapType tasks_;

  // num_full_minibatches_ is a function of the data in tasks_ (and the
  // minibatch sizes specified in opts_).  It is the number of full minibatches
  // of tasks that are pending, meaning: for each group of tasks, the number of
  // pending tasks divided by the minibatch-size for that group in integer
  // arithmetic.  This is kept updated for thread synchronization reasons,
  // because it is the shared variable that threads blocked in AcceptTask()
  // are waiting on.
  int32 num_full_minibatches_;

  // a map from 'n' to a condition variable corresponding to the condition:
  // num_full_minibatches_ <= n.  Any time the number of full minibatches drops
  // below n, the corresponding condition variable is notified (if it exists).
  std::unordered_map<int32, std::condition_variable*> no_more_than_n_minibatches_full_;

  // some static information about the neural net, computed at the start.
  int32 nnet_left_context_;
  int32 nnet_right_context_;
  int32 input_dim_;
  int32 ivector_dim_;
  int32 output_dim_;
};
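
// A hedged usage sketch of NnetBatchComputer for a single utterance, doing
// the computation in the calling thread ('opts', 'nnet', 'priors' and
// 'feats' are assumed to be set up by the caller):
//
//   NnetBatchComputer computer(opts, nnet, priors);
//   std::vector<NnetInferenceTask> tasks;
//   computer.SplitUtteranceIntoTasks(true /* output_to_cpu */, feats,
//                                    NULL, NULL, 0, &tasks);
//   for (size_t i = 0; i < tasks.size(); i++)
//     computer.AcceptTask(&(tasks[i]));
//   while (computer.Compute(true /* allow_partial_minibatch */));
//   for (size_t i = 0; i < tasks.size(); i++)
//     tasks[i].semaphore.Wait();  // wait until each chunk is done.
//   Matrix<BaseFloat> output;
//   MergeTaskOutput(tasks, &output);  // consolidate the chunks' outputs.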


/**
   This class implements a simplified interface to class NnetBatchComputer,
   which is suitable for programs that want fast GPU-based inference on a
   sequence of utterances, getting the outputs back in the same order the
   inputs were provided.
 */
class NnetBatchInference {
 public:

  NnetBatchInference(
      const NnetBatchComputerOptions &opts,
      const Nnet &nnet,
      const VectorBase<BaseFloat> &priors);

  /**
     The user should call this one by one for each utterance that needs to be
     computed (optionally interspersed with calls to GetOutput()).  The
     i-vector arguments are optional alternatives, as in
     NnetBatchComputer::SplitUtteranceIntoTasks().
  */
  void AcceptInput(const std::string &utterance_id,
                   const Matrix<BaseFloat> &input,
                   const Vector<BaseFloat> *ivector,
                   const Matrix<BaseFloat> *online_ivectors,
                   int32 online_ivector_period);

  /**
     The user should call this after the last input has been provided via
     AcceptInput(); it forces the last utterances to be flushed out so they
     can be retrieved by GetOutput().
  */
  void Finished();

  /**
     The user should call this to obtain output.  The output is given in the
     same order as the input was provided.  It returns true if an utterance's
     output was ready, and false otherwise (e.g. if it is still being
     computed, or if there is no more output).
  */
  bool GetOutput(std::string *utterance_id,
                 Matrix<BaseFloat> *output);

  ~NnetBatchInference();

 private:
  KALDI_DISALLOW_COPY_AND_ASSIGN(NnetBatchInference);

  // This is the computation thread, which is run in the background.  It will
  // exit once the user calls Finished() and all computation is completed.
  void Compute();
  // static wrapper for Compute().
  static void ComputeFunc(NnetBatchInference *object) { object->Compute(); }


  // This object implements the internals of what this class does.  It is
  // accessed both by the main thread (from where AcceptInput(), Finished() and
  // GetOutput() are called), and from the background thread in which Compute()
  // is called.
  NnetBatchComputer computer_;

  // This is set to true when the user calls Finished(); the computation
  // thread sees it and knows to flush out the remaining tasks even if the
  // minibatches are not full.
  bool is_finished_;

  // This semaphore is signaled by the main thread (the thread in which
  // AcceptInput() is called) every time a new utterance is added, and waited
  // on in the background thread in which Compute() is called.
  Semaphore tasks_ready_semaphore_;

  struct UtteranceInfo {
    std::string utterance_id;
    // The tasks into which we split this utterance.
    std::vector<NnetInferenceTask> tasks;
    // 'num_tasks_finished' is the number of tasks which are known to be
    // finished, meaning we successfully waited for those tasks' 'semaphore'
    // member.  When this reaches tasks.size(), we are ready to consolidate
    // the output into a single matrix and return it to the user.
    size_t num_tasks_finished;
  };

  // This list is only accessed directly by the main thread, by AcceptInput()
  // and GetOutput().  It is a list of utterances, with more recently added
  // ones at the back.  When utterances are given to the user by GetOutput(),
  // they are removed from the front of the list.
  std::list<UtteranceInfo*> utts_;

  int32 utterance_counter_;  // counter that increases on every utterance.

  // The thread running the Compute() process.
  std::thread compute_thread_;
};
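
// A hedged usage sketch of NnetBatchInference (the reader/writer objects are
// hypothetical, in the style of Kaldi's table readers and writers):
//
//   NnetBatchInference inference(opts, nnet, priors);
//   std::string utt;
//   Matrix<BaseFloat> output;
//   for (; !feature_reader.Done(); feature_reader.Next()) {
//     inference.AcceptInput(feature_reader.Key(), feature_reader.Value(),
//                           NULL, NULL, 0);
//     while (inference.GetOutput(&utt, &output))  // drain whatever is ready.
//       matrix_writer.Write(utt, output);
//   }
//   inference.Finished();
//   while (inference.GetOutput(&utt, &output))
//     matrix_writer.Write(utt, output);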


/**
   Decoder object that uses multiple CPU threads for the graph search, plus a
   GPU for the neural net inference (that's done by a separate
   NnetBatchComputer object).  The interface of this object should be
   accessed from only one thread, though-- presumably the main thread of the
   program.
 */
class NnetBatchDecoder {
 public:
  /**
     Constructor.
       @param [in] fst  FST that we are decoding with; it will be shared
                between all decoder threads.
       @param [in] decoder_config  Configuration object for the decoders.
       @param [in] trans_model  The transition model.
       @param [in] word_syms  Symbol table used for printing the decoded words
                to stderr; may be NULL.
       @param [in] allow_partial  If true, in cases where no final-state was
                reached on the last frame of the decoding, we still output a
                lattice (it may contain a partial sentence).
       @param [in] num_threads  The number of decoder threads to use.
       @param [in] computer  The NnetBatchComputer object used to do the
                neural net computation.
  */
  NnetBatchDecoder(const fst::Fst<fst::StdArc> &fst,
                   const LatticeFasterDecoderConfig &decoder_config,
                   const TransitionModel &trans_model,
                   const fst::SymbolTable *word_syms,
                   bool allow_partial,
                   int32 num_threads,
                   NnetBatchComputer *computer);

  /**
     The user should call this one by one for each utterance to be decoded.
     It may block while too many utterances are still being decoded, as a way
     of limiting memory use.  The i-vector arguments are optional
     alternatives, as in NnetBatchComputer::SplitUtteranceIntoTasks().
  */
  void AcceptInput(const std::string &utterance_id,
                   const Matrix<BaseFloat> &input,
                   const Vector<BaseFloat> *ivector,
                   const Matrix<BaseFloat> *online_ivectors,
                   int32 online_ivector_period);

  /*
    The user should call this function each time there was a problem with an
    utterance prior to being able to call AcceptInput()-- e.g. missing
    i-vectors.  This will update the num-failed-utterances stats which are
    stored in this class.
  */
  void UtteranceFailed();

  /*
    The user should call this when all input has been provided, i.e. when
    AcceptInput() will not be called any more.  It will block until all
    threads have terminated; after that, you can call GetOutput() until it
    returns false, which will guarantee that nothing remains to compute.
    It returns the number of utterances that have been successfully decoded.
  */
  int32 Finished();

  /**
     The user should call this to obtain output; this version is for the case
     where config.determinize_lattice == true (w.r.t. the config provided to
     the constructor).  The output is in the same order as the input was
     provided, but it may be delayed, and some outputs may be missing (e.g.
     because of decoding failures).  If a non-NULL symbol table was provided
     to the constructor, '*sentence' is set to the decoded word sequence as a
     string.  It returns true if an utterance's output was obtained, and
     false otherwise.
  */
  bool GetOutput(std::string *utterance_id,
                 CompactLattice *clat,
                 std::string *sentence);

  // This version of GetOutput is for where config.determinize_lattice == false
  // (w.r.t. the config provided to the constructor).  It is the same as the
  // other version except it outputs to a normal Lattice, not a CompactLattice.
  bool GetOutput(std::string *utterance_id,
                 Lattice *lat,
                 std::string *sentence);

  ~NnetBatchDecoder();

 private:
  KALDI_DISALLOW_COPY_AND_ASSIGN(NnetBatchDecoder);

  struct UtteranceInput {
    std::string utterance_id;
    const Matrix<BaseFloat> *input;
    const Vector<BaseFloat> *ivector;
    const Matrix<BaseFloat> *online_ivectors;
    int32 online_ivector_period;
  };

  // This object is created when a thread finishes an utterance.  For
  // utterances where decoding failed somehow, the relevant lattice
  // (compact_lat, if opts_.determinize == true, or lat otherwise) will be
  // empty (have no states).
  struct UtteranceOutput {
    std::string utterance_id;
    bool finished;
    CompactLattice compact_lat;
    Lattice lat;
    std::string sentence;  // 'sentence' is only nonempty if a non-NULL symbol
                           // table was provided to the constructor of class
                           // NnetBatchDecoder; it's the sentence as a string
                           // (a sequence of words separated by spaces).  It's
                           // used for printing the sentence to stderr, which
                           // we do in the main thread to keep the order
                           // consistent.
  };

  // This is the decoding thread, several copies of which are run in the
  // background.  It will exit once the user calls Finished() and all
  // computation is completed.
  void Decode();
  // static wrapper for Decode().
  static void DecodeFunc(NnetBatchDecoder *object) { object->Decode(); }

  // This is the computation thread; it handles the neural net inference.
  void Compute();
  // static wrapper for Compute().
  static void ComputeFunc(NnetBatchDecoder *object) { object->Compute(); }


  // Sets the priorities of the tasks in a newly provided utterance.
  void SetPriorities(std::vector<NnetInferenceTask> *tasks);

  // In the single-thread case, this sets priority_offset_ to 'priority'.
  // In the multi-threaded case it causes priority_offset_ to approach
  // 'priority' at a rate that depends on the number of threads.
  void UpdatePriorityOffset(double priority);

  // This function does the determinization (if needed) and finds the best
  // path through the lattice to update the stats.  It is expected that when
  // it is called, 'output' will have its 'lat' member set up.
  void ProcessOutputUtterance(UtteranceOutput *output);

  const fst::Fst<fst::StdArc> &fst_;
  const LatticeFasterDecoderConfig &decoder_opts_;
  const TransitionModel &trans_model_;
  const fst::SymbolTable *word_syms_;  // May be NULL.  Owned here.
  bool allow_partial_;
  NnetBatchComputer *computer_;
  std::vector<std::thread*> decode_threads_;
  std::thread compute_thread_;  // Thread that calls computer_->Compute().


  // 'input_utterance_', together with input_ready_semaphore_ and
  // input_consumed_semaphore_, are used to 'hand off' information about a
  // newly provided utterance from AcceptInput() to a decoder thread that is
  // ready to process a new utterance.
  UtteranceInput input_utterance_;
  Semaphore input_ready_semaphore_;  // Is signaled by the main thread when
                                     // AcceptInput() is called and a new
                                     // utterance is being provided (or when
                                     // the input is finished), and waited on
                                     // in the decoder threads.
  Semaphore input_consumed_semaphore_;  // Is signaled in a decoder thread when
                                        // it has finished consuming the input,
                                        // so the main thread can know when it
                                        // should continue (to avoid letting
                                        // 'input' go out of scope while it's
                                        // still needed).

  Semaphore tasks_ready_semaphore_;  // Is signaled when new tasks are added
                                     // to the computer_ object (or when we're
                                     // finished).

  bool is_finished_;  // True if the input is finished.  If this is true, a
                      // signal to input_ready_semaphore_ indicates to the
                      // decoder thread that it should terminate.

  bool tasks_finished_;  // True if we know that no more tasks will be given
                         // to the computer_ object.


  // pending_utts_ is a list of utterances that have been provided via
  // AcceptInput(), but whose decoding has not yet finished.  AcceptInput()
  // will push_back to it, and GetOutput() will pop_front().  When a decoding
  // thread has finished an utterance it will set its 'finished' member to
  // true.  There is no need to synchronize or use mutexes here.
  std::list<UtteranceOutput*> pending_utts_;

  // priority_offset_ is used in determining the priorities of nnet
  // computation tasks.  It starts off at zero and becomes more negative with
  // time, with the aim being that the priority of the first task (i.e. the
  // leftmost chunk) of a new utterance should be at about the same priority as
  // whatever chunks we are just now getting around to decoding.
  double priority_offset_;

  // Some statistics accumulated by this class, for logging and timing
  // purposes.
  double tot_like_;  // Total likelihood (of best path) over all lattices that
                     // we output.
  int64 frame_count_;  // Frame count over all lattices that we output.
  int32 num_success_;  // Number of successfully decoded files.
  int32 num_fail_;  // Number of files where decoding failed.
  int32 num_partial_;  // Number of files that were successfully decoded but
                       // reached no final-state (can only be nonzero if
                       // allow_partial_ is true).
  std::mutex stats_mutex_;  // Mutex that guards the statistics from tot_like_
                            // through num_partial_.
  Timer timer_;  // Timer used to print real-time info.
};
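
// A hedged usage sketch of NnetBatchDecoder ('decode_fst', 'computer', etc.
// are assumed to be set up by the caller, and the reader/writer objects are
// hypothetical, in the style of Kaldi's table readers and writers):
//
//   NnetBatchDecoder decoder(*decode_fst, decoder_config, trans_model,
//                            word_syms, true /* allow_partial */,
//                            num_threads, &computer);
//   for (; !feature_reader.Done(); feature_reader.Next())
//     decoder.AcceptInput(feature_reader.Key(), feature_reader.Value(),
//                         NULL, NULL, 0);
//   int32 num_done = decoder.Finished();
//   std::string utt, sentence;
//   CompactLattice clat;
//   while (decoder.GetOutput(&utt, &clat, &sentence))
//     clat_writer.Write(utt, clat);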


}  // namespace nnet3
}  // namespace kaldi

#endif  // KALDI_NNET3_NNET_BATCH_COMPUTE_H_