nnet-optimize-utils.h
Go to the documentation of this file.
1 // nnet3/nnet-optimize-utils.h
2 
3 // Copyright 2015 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #ifndef KALDI_NNET3_NNET_OPTIMIZE_UTILS_H_
21 #define KALDI_NNET3_NNET_OPTIMIZE_UTILS_H_
22 
23 #include <mutex>
24 #include <list>
25 #include "nnet3/nnet-compile.h"
26 #include "nnet3/nnet-analyze.h"
27 
28 
29 namespace kaldi {
30 namespace nnet3 {
31 
32 
33 struct NnetOptimizeOptions; // Forward declaration.
34 
134  public:
136  const Nnet &nnet,
137  NnetComputation *computation);
138  // Note: you can call this only once. If it returns true, it means it has
139  // merged variables. In this case, you have the option to instantiate another
140  // copy of the class and try again with that other copy.
141  bool MergeVariables();
142 
143  private:
159  std::pair<bool,bool> MayBeMerged(int32 command, int32 s1, int32 s2) const;
160 
161  // Merges two matrices, whether left merge or right merge. s_to_keep and
162  // s_to_discard are the submatrix-indexes we will keep and discard
163  // respectively (these are s1 and s2 in some order).
164  void DoMerge(int32 command_index, int32 s_to_keep, int32 m_to_discard);
165 
167  void MarkAsDirty(int32 s);
168 
169  void Initialize();
170 
172  const Nnet &nnet_;
174 
176 
177  // lists of submatrices that correspond to each matrix.
178  std::vector<std::vector<int32> > matrix_to_submatrix_;
179 
180  // for each variable (as defined by analyzer_.variables), true if
181  // we have already performed a merge on it.
182  std::vector<bool> variable_dirty_;
183 
185 };
186 
196 void ExtendMatrices(NnetComputation *computation);
197 
198 
208 void ConsolidateModelUpdate(const Nnet &nnet,
209  NnetComputation *computation);
210 
211 
212 
213 
214 // Class DerivativeTimeLimiter is used inside LimitDerivativeTimes().
215 // Its function is to modify the computation so that we don't work
216 // with derivatives outside of a specified range of t values; this is
217 // useful, for instance, in BLSTMs where you might have a fair amount of
218 // left and right context in the training examples but don't want to
219 // propagate the derivatives to there.
220 //
221 // We require that the computation have debug info set up
222 // (!matrix_debug_info.empty()) and that this be the first
223 // optimization you perform. This means that the debug_info will
224 // be accurate and that all matrices will be initialized with
225 // zero contents.
227  public:
228  DerivativeTimeLimiter(const Nnet &nnet,
229  int32 min_deriv_time,
230  int32 max_deriv_time,
231  NnetComputation *computation);
232 
233  void LimitDerivTimes();
234 
235  private:
236 
237  // sets up matrix_prune_info_.
238  void ComputeMatrixPruneInfo();
239 
240  // sets up submatrix_map_ and submatrix_map_if_deriv_.
241  void ComputeSubmatrixMaps();
242 
243  // modifies all the commands as appropriate to reflect that some derivative
244  // values are zero (i.e. save any computation we can, based on this
245  // assumption).
246  void ModifyCommands();
247 
248  // this function, called after we've modified the commands to operate on
249  // submatrices of the original matrices, works out for which of the matrices
250  // we can actually limit their extent in time, and changes the way the
251  // matrices are allocated (it may remove some matrices entirely).
252  void PruneMatrices();
253 
254  // this function modifies commands of type kPropagate to set the memo indexes
255  // to zero if the memo indexes appear in the list memos_to_delete_. It's
256  // because if a backprop command has been deleted, the propagate command
257  // should no longer store a memo.
258  void RemoveUnusedMemos();
259 
260 
261  // called from PruneMatrices only for matrices that are derivatives,
262  // not inputs or outputs of the computation, and which are partly
263  // inside the time range, this function returns true if we can
264  // limit the size of the matrix (because variables outside the
265  // desired range are never accessed), and false otherwise.
266  inline bool CanLimitMatrix(const Analyzer &analyzer,
267  int32 matrix_index) const;
268 
269  // called from PruneMatrices after it has figured out which matrices we need
270  // to limit to a row-range, this function changes computation->submatrices and
271  // computation->matrices in the way required to do that.
272  inline void LimitMatrices(const std::vector<bool> &will_limit);
273 
274  // does the processing for a command of type kMatrixCopy or kMatrixAdd.
275  void MapSimpleMatrixCommand(NnetComputation::Command *c);
276 
277  // does the processing for a command of type kCopyRows or kAddRows, where
278  // 1st and 2nd args are submatrix indexes and the 3rd arg is a vector of
279  // row-indexes.
280  void MapIndexesCommand(NnetComputation::Command *c);
281 
282  // does the processing for a command of type kAddRowsMulti, kAddToRowsMulti,
283  // kCopyRowsMulti or kCopyToRowsMulti, 1st arg is submatrix index that the
284  // command is called with, and 2nd arg is 'indexes_multi' index (which
285  // contains pairs (source-submatrix, source-row)).
286  void MapIndexesMultiCommand(NnetComputation::Command *c);
287 
288  // does the processing for a command of type kAddRowRanges.
289  void MapAddRowRangesCommand(NnetComputation::Command *c);
290 
291  // Modifies this command to take into account prune_info_. At this point we
292  // don't actually reduce the size of the matrices, we simply make the commands
293  // operate on submatrices of the original matrices where possible- or
294  // delete them completely if their output is all zeros or for other reasons
295  // we detect that they would be no-ops.
296  // Note: this calls computation_->NewSubMatrix, and will generate duplicates
297  // of the same submatrix which we'll later remove in RemoveOrphanMatrices.
298  void ModifyCommand(NnetComputation::Command *command);
299 
300  // this will detect which matrices we can reduce the allocated size of,
301  // and reduce their size.
302  void ResizeMatrices();
303 
304  // Requires that we have mapped 'initial_submatrix' to 'new_submatrix' in
305  // an operation that may have removed some data on the left and/or the
306  // right (but still they point to the same underlying matrix). Outputs
307  // to 'left_prune' and 'right_prune' the number of rows we have
308  // removed on the left and on the right respectively.
309  inline void GetPruneValues(int32 initial_submatrix,
310  int32 new_submatrix,
311  int32 *left_prune,
312  int32 *right_prune) const;
313 
314  // This helper function, used while mapping commands, returns true if the
315  // Cindex represented by the pair (submatrix, row_index) has a 't' value
316  // within the range [min_deriv_time_, max_deriv_time_].
317  bool RowIsKept(int32 submatrix,
318  int32 row_index) const;
319 
320 
322  bool is_deriv; // true if the matrix represents a derivative (copied from
323  // the debug-info; repeated here for convenience).
324  bool fully_inside_range; // True if the matrix is completely inside the time range
325  // specified.
326  bool partly_inside_range; // true if the matrix is partly (but not fully)
327  // inside the time range specified.
328  int32 row_begin; // if partly_inside_range, the first row that's within the time range (i.e. for which
329  // min_deriv_time_ <= t < max_deriv_time_).
330  int32 row_end; // if partly_inside_range, one plus the last row that's within
331  // the time range.
332  };
333 
334 
335  const Nnet &nnet_;
336 
339 
340  // the computation; we require it to have debug info set up
341  // (otherwise you shouldn't be instantiating this class).
343 
344  // for each matrix index > 0, the index of a submatrix that consists of
345  // the entirety of that matrix.
346  std::vector<int32> whole_submatrices_;
347 
348  std::vector<MatrixPruneInfo> matrix_prune_info_;
349 
350  // for each submatrix in the original range of computation_->submatrices,
351  // submatrix_map_ maps it to itself if the submatrix is completely inside the
352  // time-range, or to zero if it's completely outside the time-range, or to a
353  // newly created submatrix-index if it's partly inside the time-range.
354  std::vector<int32> submatrix_map_;
355 
356  // submatrix_map_if_deriv_ contains the quantity:
357  // IsDerivative(s) ? submatrix_map_[s] : s,
358  // where IsDerivative(s) is true if s is part of a matrix that (according to its
359  // debug info) represents a derivative.
360  // this comes up so frequently that storing it separately seemed like a good idea.
361  std::vector<int32> submatrix_map_if_deriv_;
362 
363  std::vector<MatrixPruneInfo> prune_info_;
364 
365  // List of indexes of memos that will no longer be stored because the backprop
366  // commands using them were deleted.
367  std::unordered_set<int32> memos_to_delete_;
368 };
369 
370 
371 // This utility function, used in code that calls LimitDerivativeTimes(), returns
372 // the largest time 't' in any of the 'outputs' in the computation request,
373 // or crashes if there are no outputs (or no cindexes in those outputs).
375 
376 // This is the top-level interface to limit the times on which derivatives are
377 // computed (e.g. for truncated BPTT); internally it uses class
378 // DerivativeTimeLimiter. Will do nothing if min_deriv_time and max_deriv_time are
379 // their default -inf,+inf values.
380 void LimitDerivativeTimes(const Nnet &nnet,
381  int32 min_deriv_time,
382  int32 max_deriv_time,
383  NnetComputation *computation);
384 
402 bool RequestIsDecomposable(const ComputationRequest &request,
403  ComputationRequest *mini_request,
404  int32 *num_n_values);
405 
406 
428 void ExpandComputation(const Nnet &nnet,
429  const MiscComputationInfo &misc_info,
430  const NnetComputation &computation,
431  bool need_debug_info,
432  int32 num_n_values,
433  NnetComputation *expanded_computation);
434 
435 
436 
445 bool ReplaceRowWithMatrixOps(NnetComputation *computation);
446 
459 bool SnipRowOps(NnetComputation *computation);
460 
461 
476 bool SplitRowOps(NnetComputation *computation);
477 
484 void RenumberComputation(NnetComputation *computation);
485 
486 
488 void RemoveNoOps(NnetComputation *computation);
489 
495  std::vector<int32*> *submatrix_args);
496 
501 bool MatrixIsUnused(const Analyzer &analyzer,
502  const NnetComputation &computation,
503  int32 m);
504 
511 void RemoveCommandsForUnusedMatrix(const Analyzer &analyzer,
512  int32 m,
513  NnetComputation *computation);
514 
515 
520 void IdentifySubmatrixArgs(std::vector<NnetComputation::Command> *commands,
521  std::vector<int32*> *submatrix_args);
522 
530  std::vector<int32*> *submatrix_args);
531 
532 
536 void IdentifyIndexesMultiArgs(std::vector<NnetComputation::Command> *commands,
537  std::vector<int32*> *indexes_multi_args);
538 
542 void IdentifyIndexesArgs(std::vector<NnetComputation::Command> *commands,
543  std::vector<int32*> *indexes_args);
544 
548 void IdentifyIndexesArgs(std::vector<NnetComputation::Command> *commands,
549  std::vector<int32*> *indexes_args);
550 
554 void IdentifyIndexesRangesArgs(std::vector<NnetComputation::Command> *commands,
555  std::vector<int32*> *indexes_ranges_args);
556 
570 void InsertCommands(
571  std::vector<std::pair<int32, NnetComputation::Command> > *commands,
572  NnetComputation *computation);
573 
592 void OptimizeMemoryCompression(const Nnet &nnet,
593  int32 memory_compression_level,
594  NnetComputation *computation);
595 
596 
609 void OptimizeLoopedComputation(const Nnet &nnet,
610  NnetComputation *computation);
611 
612 
616 void FixGotoLabel(NnetComputation *computation);
617 
618 
626  public:
627  ComputationCache(int32 cache_capacity);
628 
629  // Note: if something fails in Read(), or the written cache was from an older
630  // format, it will just leave the cache empty.
631  void Read(std::istream &is, bool binary);
632 
633  void Write(std::ostream &os, bool binary) const;
634 
635 
636  // Searches for the computation corresponding to this computation request, and returns
637  // it if cached, or NULL (as std::shared_ptr) if not. (We need shared_ptr to
638  // handle multi-threaded operation, so that if the computation is ejected from
639  // the cache by another thread, it won't be deleted while still in use). This
640  // function also moves this computation to the end of the
641  // most-recently-accessed queue, which is why it's not const.
642  std::shared_ptr<const NnetComputation> Find(const ComputationRequest &request);
643 
644 
645  // Inserts the computation into the cache-- this is assumed to be the
646  // computation for the computation-request 'request'. Returns a shared_ptr
647  // which can be used to access the object. This function takes ownership of
648  // 'computation'.
649  std::shared_ptr<const NnetComputation> Insert(const ComputationRequest &request,
650  const NnetComputation *computation);
651 
652  ~ComputationCache();
653 
654  // Checks the stored computation for correctness.
655  void Check(const Nnet &nnet) const;
656  private:
657 
658  std::mutex mutex_; // Read/write mutex.
659 
661 
662  // The access queue for keeping track of the freshness of computation.
663  // Most-recently-accessed computation is at the end, and
664  // least-recently-accessed computation is at the beginning. Together with
665  // computation_cache_, this forms a most-recently-used (MRU) cache for
666  // Computations, indexed by ComputationRequest. The pointers are owned in
667  // computation_cache_.
668  typedef std::list<const ComputationRequest*> AqType;
670 
671  // Map from computation-request to pair of (computation, and position in
672  // access_queue_). Used for fast lookup of previously compiled computations.
673  // All pointers are owned here.
674  typedef unordered_map<const ComputationRequest*,
675  std::pair<std::shared_ptr<const NnetComputation>, AqType::iterator>,
679 };
680 
681 
682 
683 
684 } // namespace nnet3
685 } // namespace kaldi
686 
687 
688 #endif
bool MatrixIsUnused(const Analyzer &analyzer, const NnetComputation &computation, int32 m)
This function returns true if matrix 1 <= m < computation->matrices.size() is unused, defined as: it is not an input or an output, and is not accessed other than via commands of type kAllocMatrix, kDeallocMatrix, and kSetConst.
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
This class is responsible for merging matrices, although you probably want to access it via the f...
std::vector< MatrixPruneInfo > prune_info_
bool SplitRowOps(NnetComputation *computation)
This function detects cases where commands of type kAddRowsMulti, kAddToRowsMulti, kCopyRowsMulti, kCopyToRowsMulti use indexes that correspond to at most two submatrices, in two distinct ranges without gaps filled by -1's, and could be converted to at most two commands of type kMatrixAdd, kMatrixCopy, kAddRows or kCopyRows.
void OptimizeLoopedComputation(const Nnet &nnet, NnetComputation *computation)
This function tries to optimize computation 'computation' for a 'looped' computation.
void IdentifySubmatrixArgs(NnetComputation::Command *c, std::vector< int32 *> *submatrix_args)
This function outputs to "submatrix_args" the addresses of a subset of arguments arg1 through arg6 in...
void RenumberComputation(NnetComputation *computation)
This function detects submatrices and matrices that are never used (e.g.
void InsertCommands(std::vector< std::pair< int32, NnetComputation::Command > > *new_commands, NnetComputation *computation)
Inserts commands into the computation at the requested places.
void IdentifySubmatrixArgsInComputation(NnetComputation *computation, std::vector< int32 *> *submatrix_args)
This function outputs to "submatrix_args" the addresses of integers in 'computation' that correspond ...
This file contains utilities for analyzing and checking computations, which are used in the optimizat...
bool RequestIsDecomposable(const ComputationRequest &request, ComputationRequest *mini_request, int32 *num_n_values)
This function, used in 'shortcut' compilation where we first compile a smaller computation with the s...
kaldi::int32 int32
void ExtendMatrices(NnetComputation *computation)
This is not really an optimization in itself but it can make things easier for class VariableMergingO...
void LimitDerivativeTimes(const Nnet &nnet, int32 min_deriv_time, int32 max_deriv_time, NnetComputation *computation)
void OptimizeMemoryCompression(const Nnet &nnet, int32 memory_compression_level, NnetComputation *computation)
Performs optimization to reduce memory usage where possible, making use of the kCompressMatrix and kD...
void DoMerge(int32 command_index, int32 s_to_keep, int32 m_to_discard)
int32 MaxOutputTimeInRequest(const ComputationRequest &request)
std::vector< MatrixPruneInfo > matrix_prune_info_
std::vector< std::vector< int32 > > matrix_to_submatrix_
void IdentifyIndexesMultiArgs(std::vector< NnetComputation::Command > *commands, std::vector< int32 *> *indexes_multi_args)
Identifies in the vector of commands, arguments that correspond to indexes into the computation's ind...
std::unordered_set< int32 > memos_to_delete_
bool ReplaceRowWithMatrixOps(NnetComputation *computation)
This function detects cases where commands of type kCopyRows, kAddRows or kAddToRows can be converted...
void FixGotoLabel(NnetComputation *computation)
This function ensures that the arg1 of a final command of type kGotoLabel is the same as the command ...
unordered_map< const ComputationRequest *, std::pair< std::shared_ptr< const NnetComputation >, AqType::iterator >, ComputationRequestHasher, ComputationRequestPtrEqual > CacheType
void ExpandComputation(const Nnet &nnet, const MiscComputationInfo &misc_info, const NnetComputation &computation, bool need_debug_info, int32 num_n_values, NnetComputation *expanded_computation)
This function is used in 'shortcut' compilation to expand a computation that has been compiled for ex...
void ConsolidateModelUpdate(const Nnet &nnet, NnetComputation *computation)
This optimization consolidates the model-update part of backprop commands, for components in (e...
void MarkAsDirty(int32 s)
Marks the variables underlying submatrix 's' as dirty.
VariableMergingOptimizer(const NnetOptimizeOptions &config, const Nnet &nnet, NnetComputation *computation)
void RemoveNoOps(NnetComputation *computation)
Removes commands of type kNoOperation in the computation.
void IdentifyIndexesRangesArgs(std::vector< NnetComputation::Command > *commands, std::vector< int32 *> *indexes_ranges_args)
Identifies in the vector of commands, arguments that correspond to indexes into the computation's 'in...
void RemoveCommandsForUnusedMatrix(const Analyzer &analyzer, int32 m, NnetComputation *computation)
This function removes from 'computation' the commands accessing matrix 'm', which is assumed to be un...
std::list< const ComputationRequest * > AqType
bool SnipRowOps(NnetComputation *computation)
This function detects cases where commands of type kCopyRows, kAddRows, kAddRowsMulti, kAddToRowsMulti, kCopyRowsMulti, kCopyToRowsMulti or kAddRowRanges use indexes that start or end with -1's or equivalents, and replace them with similar commands that act on a sub-matrix of the matrices they are currently acting on.
Class ComputationCache is used inside class CachingOptimizingCompiler to cache previously computed co...
void IdentifyIndexesArgs(std::vector< NnetComputation::Command > *commands, std::vector< int32 *> *indexes_args)
Identifies in the vector of commands, arguments that correspond to indexes into the computation's 'in...
This struct exists to set up various pieces of analysis; it helps avoid the repetition of code where ...
Definition: nnet-analyze.h:294
std::pair< bool, bool > MayBeMerged(int32 command, int32 s1, int32 s2) const
This function returns a pair of bools saying whether we can do a (left and/or right) merge respective...