online-nnet2-decoding-threaded.h
1 // online2/online-nnet2-decoding-threaded.h
2 
3 // Copyright 2014-2015 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABILITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
21 #ifndef KALDI_ONLINE2_ONLINE_NNET2_DECODING_THREADED_H_
22 #define KALDI_ONLINE2_ONLINE_NNET2_DECODING_THREADED_H_
23 
24 #include <string>
25 #include <vector>
26 #include <deque>
27 #include <mutex>
28 #include <thread>
29 
30 #include "matrix/matrix-lib.h"
31 #include "util/common-utils.h"
32 #include "base/kaldi-error.h"
34 #include "nnet2/am-nnet.h"
38 #include "hmm/transition-model.h"
39 #include "util/kaldi-semaphore.h"
40 
41 namespace kaldi {
44 
45 
65 class ThreadSynchronizer {
66  public:
67  enum ThreadType { kProducer, kConsumer };
68 
69  // Most calls to this class should specify the thread-type of the caller:
70  // producing or consuming. The behavior of this class is in fact symmetric
71  // between the two types of thread.
73 
74  // All functions returning bool will return true normally, and false if
75  // SetAbort() has been called; if they return false, you should probably call
76  // SetAbort() on any other ThreadSynchronizer classes you are using and then
77  // return from the thread.
78 
79  // Call this to lock the object being guarded.
80  bool Lock(ThreadType t);
81 
82  // Call this to unlock the object being guarded, if you don't want the next call to
83  // Lock to stall.
84  bool UnlockSuccess(ThreadType t);
85 
86  // Call this if you want the next call to Lock() to stall until the other
87  // (producer/consumer) thread has locked and then unlocked the mutex. Note
88  // that, if the other thread then calls Lock and then UnlockFailure, this will
89  // generate a printed warning (and if repeated too many times, an exception).
90  bool UnlockFailure(ThreadType t);
91 
92  // Sets abort_ flag so future calls will return false, and future calls to
93  // Lock() won't lock the mutex but will immediately return false.
94  void SetAbort();
95 
97 
98  private:
99  bool abort_;
100  bool producer_waiting_; // true if producer is/will be waiting on semaphore
101  bool consumer_waiting_; // true if consumer is/will be waiting on semaphore
102  std::mutex mutex_; // Locks the buffer object.
103  ThreadType held_by_; // Record of which thread is holding the mutex (if
104  // held); else undefined. Used for validation of input.
105  Semaphore producer_semaphore_; // The producer thread waits on this semaphore
106  Semaphore consumer_semaphore_; // The consumer thread waits on this semaphore
107  int32 num_errors_; // Number of times the threads alternated doing Lock() and
108  // UnlockFailure(). This should not happen at all; but
109  // it's more user-friendly to simply warn a few times; and then
110  // only after a number of errors, to fail.
111  KALDI_DISALLOW_COPY_AND_ASSIGN(ThreadSynchronizer);
112 };
113 
114 
115 
116 
117 // This is the configuration class for SingleUtteranceNnet2DecoderThreaded. The
118 // actual command line program requires other configs that it creates
119 // separately, and which are not included here: namely,
120 // OnlineNnet2FeaturePipelineConfig and OnlineEndpointConfig.
121 struct OnlineNnet2DecodingThreadedConfig {
122 
123  LatticeFasterDecoderConfig decoder_opts;
124 
125  BaseFloat acoustic_scale;
126 
127  int32 max_buffered_features; // maximum frames of features we allow to be
128  // held in the feature buffer before we block
129  // the feature-processing thread.
130 
131  int32 feature_batch_size; // maximum number of frames of features we process
132  // at a time before unlocking the mutex. The only
133  // real cost here is a mutex lock/unlock, so it's
134  // OK to make this fairly small.
135  int32 max_loglikes_copy; // maximum unused frames of log-likelihoods we will
136  // copy from the decodable object back into another
137  // matrix to be supplied to the decodable object.
138  // Make this too large and it will block the
139  // decoder-search thread while copying; too small
140  // and the nnet-evaluation thread may get blocked
141  // for too long while waiting for the decodable
142  // thread to be ready.
143  int32 nnet_batch_size; // batch size (number of frames) we evaluate in the
144  // neural net, if this many is available. To take
145  // best advantage of BLAS, you may want to set this
146  // fairly large, e.g. 32 or 64 frames. It probably
147  // makes sense to tune this a bit.
148  int32 decode_batch_size; // maximum number of frames at a time that we decode
149  // before unlocking the mutex. The only real cost
150  // here is a mutex lock/unlock, so it's OK to make
151  // this fairly small.
152 
153  OnlineNnet2DecodingThreadedConfig() {
154  acoustic_scale = 0.1;
155  max_buffered_features = 100;
156  feature_batch_size = 2;
157  nnet_batch_size = 32;
158  max_loglikes_copy = 20;
159  decode_batch_size = 2;
160  }
161 
162  void Check();
163 
164  void Register(OptionsItf *opts) {
165  decoder_opts.Register(opts);
166  opts->Register("acoustic-scale", &acoustic_scale, "Scale used on acoustics "
167  "when decoding");
168  opts->Register("max-buffered-features", &max_buffered_features, "Obscure "
169  "setting, affects multi-threaded decoding.");
170  opts->Register("feature-batch-size", &feature_batch_size, "Obscure "
171  "setting, affects multi-threaded decoding.");
172  opts->Register("nnet-batch-size", &nnet_batch_size, "Maximum batch size "
173  "(in frames) used when evaluating neural net likelihoods");
174  opts->Register("max-loglikes-copy", &max_loglikes_copy, "Obscure "
175  "setting, affects multi-threaded decoding.");
176  opts->Register("decode-batch-size", &decode_batch_size, "Obscure "
177  "setting, affects multi-threaded decoding.");
178  }
179 };
180 
190 class SingleUtteranceNnet2DecoderThreaded {
191  public:
192  // Constructor. Unlike SingleUtteranceNnet2Decoder, we create the
193  // feature_pipeline object inside this class, since access to it needs to be
194  // controlled by a mutex and this class knows how to handle that. The
195  // feature_info and adaptation_state arguments are used to initialize the
196  // (locally owned) feature pipeline.
197  SingleUtteranceNnet2DecoderThreaded(
198  const OnlineNnet2DecodingThreadedConfig &config,
199  const TransitionModel &tmodel,
200  const nnet2::AmNnet &am_nnet,
201  const fst::Fst<fst::StdArc> &fst,
202  const OnlineNnet2FeaturePipelineInfo &feature_info,
203  const OnlineIvectorExtractorAdaptationState &adaptation_state,
204  const OnlineCmvnState &cmvn_state);
205 
206 
207 
210  void AcceptWaveform(BaseFloat samp_freq,
211  const VectorBase<BaseFloat> &wave_part);
212 
216  int32 NumWaveformPiecesPending();
217 
222  void InputFinished();
223 
230  void TerminateDecoding();
231 
235  void Wait();
236 
242  void FinalizeDecoding();
243 
249  int32 NumFramesReceivedApprox() const;
250 
255  int32 NumFramesDecoded() const;
256 
269  void GetLattice(bool end_of_utterance,
270  CompactLattice *clat,
271  BaseFloat *final_relative_cost) const;
272 
282  void GetBestPath(bool end_of_utterance,
283  Lattice *best_path,
284  BaseFloat *final_relative_cost) const;
285 
288  bool EndpointDetected(const OnlineEndpointConfig &config);
289 
296  void GetAdaptationState(OnlineIvectorExtractorAdaptationState *adaptation_state);
297 
304  void GetCmvnState(OnlineCmvnState *cmvn_state);
305 
310  BaseFloat GetRemainingWaveform(Vector<BaseFloat> *waveform_out) const;
311 
313  private:
314 
315  // This function will instruct all threads to abort operation as soon as they
316  // can safely do so, by calling SetAbort() in the threads
317  void AbortAllThreads(bool error);
318 
319  // This function waits for all the threads that have been spawned. It is
320  // called in the destructor and Wait(). If called twice it is not an error.
321  void WaitForAllThreads();
322 
323 
324 
325  // this function runs the thread that does the feature extraction and
326  // neural-net evaluation. In case of failure, calls
327  // me->AbortAllThreads(true).
328  static void RunNnetEvaluation(SingleUtteranceNnet2DecoderThreaded *me);
329  // member-function version of RunNnetEvaluation, called by RunNnetEvaluation.
330  bool RunNnetEvaluationInternal();
331  // the following function is called inside RunNnetEvaluationInternal(); it
332  // takes the log and subtracts the prior.
333  void ProcessLoglikes(const CuVector<BaseFloat> &log_inv_prior,
334  CuMatrixBase<BaseFloat> *loglikes);
335  // called from RunNnetEvaluationInternal(). Returns true in the normal case,
336  // false on error; if it returns false, then we expect that the calling thread
337  // will terminate. This assumes the caller has already
338  // locked feature_pipeline_mutex_.
339  bool FeatureComputation(int32 num_frames_output);
340 
341 
342  // this function runs the thread that does the decoder search.
343  // In case of failure, calls me->AbortAllThreads(true).
344  static void RunDecoderSearch(SingleUtteranceNnet2DecoderThreaded *me);
345  // member-function version of RunDecoderSearch, called by RunDecoderSearch.
346  bool RunDecoderSearchInternal();
347 
348 
349  // Member variables:
350 
352 
354 
356 
357 
358  // sampling_rate_ is set the first time AcceptWaveform is called.
359  BaseFloat sampling_rate_;
360  // A record of how many samples have been provided so
361  // far via calls to AcceptWaveform.
362  int64 num_samples_received_;
363 
364  // The next two variables are written to by AcceptWaveform from the main
365  // thread, and read by the feature-processing thread; they are guarded by
366  // waveform_synchronizer_. There is no bound on the buffer size here.
367  // Later-arriving data is appended to the vector. When InputFinished() is
368  // called from the main thread, the main thread sets input_finished_ = true.
369  // sampling_rate_ is only needed for checking that it matches the config.
370  bool input_finished_;
371  std::deque< Vector<BaseFloat>* > input_waveform_;
372 
373 
375 
376  // feature_pipeline_ is accessed by the nnet-evaluation thread, by the main
377  // thread if GetAdaptationState() is called, and by the decoding thread via
378  // ComputeCurrentTraceback() if online silence weighting is being used. It is
379  // guarded by feature_pipeline_mutex_.
380  OnlineNnet2FeaturePipeline feature_pipeline_;
381  std::mutex feature_pipeline_mutex_;
382 
383  // The next two variables are required only for implementation of the function
384  // GetRemainingWaveform(). After we take waveform from the input_waveform_
385  // queue to be processed into features, we put them onto this deque. Then we
386  // discard from this queue any that we can discard because we have already
387  // decoded those frames (see num_frames_decoded_), and we increment
388  // num_samples_discarded_ by the corresponding number of samples.
389  std::deque< Vector<BaseFloat>* > processed_waveform_;
390  int64 num_samples_discarded_;
391 
392  // This object is used to control the (optional) downweighting of silence in iVector estimation,
393  // which is based on the decoder traceback.
396 
397 
398  // this Decodable object just stores a matrix of scaled log-likelihoods
399  // obtained by the nnet-evaluation thread. It is produced by the
400  // nnet-evaluation thread and consumed by the decoder-search thread. The
401  // decoding thread sets num_frames_decoded_ so the nnet-evaluation thread
402  // knows which frames of log-likelihoods it can discard. Both of these
403  // variables are guarded by decodable_synchronizer_. Note:
404  // the num_frames_decoded_ may be less than the current number of frames
405  // the decoder has decoded; the decoder thread sets this variable when it
406  // locks this mutex.
410 
411  // the decoder_ object contains everything related to the graph search.
413  // decoder_mutex_ guards the decoder_ object. It is usually held by the decoding
414  // thread (where it is released and re-obtained on each frame), but is obtained
415  // by the main (parent) thread if you call functions like NumFramesDecoded(),
416  // GetLattice() and GetBestPath().
417  mutable std::mutex decoder_mutex_; // declared as mutable because we mutate
418  // this mutex in const methods
419 
420  // This contains the std::thread objects for the nnet-evaluation and
421  // decoder-search threads respectively (they are joined in
422  // Wait()).
423  std::thread threads_[2];
424 
425  // This is set to true if AbortAllThreads was called for any reason, including
426  // if someone called TerminateDecoding().
427  bool abort_;
428 
429  // This is set to true if any kind of unexpected error is encountered,
430  // including if exceptions are raised in any of the threads. This will
431  // normally indicate a coding error or a malloc failure, i.e. something
432  // we should never encounter.
433 
434 };
435 
436 
438 
439 } // namespace kaldi
440 
441 
442 
443 #endif // KALDI_ONLINE2_ONLINE_NNET2_DECODING_THREADED_H_