nnet-optimize.h
Go to the documentation of this file.
1 // nnet3/nnet-optimize.h
2 
3 // Copyright 2015-2016 Johns Hopkins University (author: Daniel Povey)
4 // 2015 Xiaohui Zhang
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #ifndef KALDI_NNET3_NNET_OPTIMIZE_H_
22 #define KALDI_NNET3_NNET_OPTIMIZE_H_
23 
24 #include "nnet3/nnet-compile.h"
25 #include "nnet3/nnet-analyze.h"
27 
28 namespace kaldi {
29 namespace nnet3 {
30 
31 // Options class for optimizing a NnetComputation. The main projected use for
32 // this is in debugging the optimization code itself, so that if an error is
33 // detected, we can work out which optimization was responsible for the error.
34 // See the Register() function below for option-specific documentation.
36  // Caution: if adding or removing members, the Read and Write functions and
37  // the == operator should be modified. This relates to computation caching.
38  bool optimize; // setting this to false disables all optimization.
57  // optimize_looped_computation is a 'hidden config' not available from
58  // the command line; it's set to true to enable the optimization for
59  // looped computation that turns a linear computation into a loop.
61 
63  optimize(true),
64  consolidate_model_update(true),
65  propagate_in_place(true),
66  backprop_in_place(true),
67  optimize_row_ops(true),
68  split_row_ops(true),
69  extend_matrices(true),
70  convert_addition(true),
71  remove_assignments(true),
72  allow_left_merge(true),
73  allow_right_merge(true),
74  initialize_undefined(true),
75  move_sizing_commands(true),
76  allocate_from_other(true),
77  min_deriv_time(std::numeric_limits<int32>::min()),
78  max_deriv_time(std::numeric_limits<int32>::max()),
79  max_deriv_time_relative(std::numeric_limits<int32>::max()),
80  snip_row_ops(true),
81  memory_compression_level(1),
82  optimize_looped_computation(false) { }
83 
84  void Register(OptionsItf *opts) {
85  opts->Register("optimize", &optimize, "Set this to false to turn off all "
86  "optimizations");
87  opts->Register("consolidate-model-update", &consolidate_model_update,
88  "Set to false to disable optimization that consolidates "
89  "the model-update phase of backprop (e.g. for recurrent "
90  "architectures");
91  opts->Register("propagate-in-place", &propagate_in_place, "Set to false to "
92  "disable optimization that allows in-place propagation");
93  opts->Register("backprop-in-place", &backprop_in_place, "Set to false to "
94  "disable optimization that allows in-place backprop");
95  opts->Register("extend-matrices", &extend_matrices, "This optimization "
96  "can reduce memory requirements for TDNNs when applied "
97  "together with --convert-addition=true");
98  opts->Register("optimize-row-ops", &optimize_row_ops, "Set to false to "
99  "disable certain optimizations that act on operations of "
100  "type *Row*.");
101  opts->Register("split-row-ops", &split_row_ops, "Set to false to disable "
102  "an optimization that may replace some operations of type "
103  "kCopyRowsMulti or kAddRowsMulti with up to two simpler "
104  "operations.");
105  opts->Register("convert-addition", &convert_addition, "Set to false to "
106  "disable the optimization that converts Add commands into "
107  "Copy commands wherever possible.");
108  opts->Register("remove-assignments", &remove_assignments, "Set to false to "
109  "disable optimization that removes redundant assignments");
110  opts->Register("allow-left-merge", &allow_left_merge, "Set to false to "
111  "disable left-merging of variables in remove-assignments "
112  "(obscure option)");
113  opts->Register("allow-right-merge", &allow_right_merge, "Set to false to "
114  "disable right-merging of variables in remove-assignments "
115  "(obscure option)");
116  opts->Register("initialize-undefined", &initialize_undefined, "Set to false "
117  "to disable optimization that avoids redundant zeroing");
118  opts->Register("move-sizing-commands", &move_sizing_commands, "Set to false "
119  "to disable optimization that moves matrix allocation and "
120  "deallocation commands to conserve memory.");
121  opts->Register("allocate-from-other", &allocate_from_other, "Instead of "
122  "deleting a matrix of a given size and then allocating "
123  "a matrix of the same size, allow re-use of that memory");
124  opts->Register("min-deriv-time", &min_deriv_time, "You can set this to "
125  "the minimum t value that you want derivatives to be computed "
126  "at when updating the model. This is an optimization that "
127  "saves time in the backprop phase for recurrent frameworks");
128  opts->Register("max-deriv-time", &max_deriv_time, "You can set this to "
129  "the maximum t value that you want derivatives to be computed "
130  "at when updating the model. This is an optimization that "
131  "saves time in the backprop phase for recurrent frameworks");
132  opts->Register("max-deriv-time-relative", &max_deriv_time_relative,
133  "An alternative mechanism for setting the --max-deriv-time, "
134  "suitable for situations where the length of the egs is "
135  "variable. If set, it is equivalent to setting the "
136  "--max-deriv-time to this value plus the largest 't' value "
137  "in any 'output' node of the computation request.");
138  opts->Register("snip-row-ops", &snip_row_ops, "Set this to false to "
139  "disable an optimization that reduces the size of certain "
140  "per-row operations");
141  opts->Register("memory-compression-level", &memory_compression_level,
142  "This is only relevant to training, not decoding. Set this "
143  "to 0,1,2; higher levels are more aggressive at reducing "
144  "memory by compressing quantities needed for backprop, "
145  "potentially at the expense of speed and the accuracy "
146  "of derivatives. 0 means no compression at all; 1 means "
147  "compression that shouldn't affect results at all.");
148 
149  }
150  void Read(std::istream &is, bool binary);
151  void Write(std::ostream &os, bool binary) const;
152  bool operator == (const NnetOptimizeOptions &other) const;
153 };
154 
155 
156 /* This utility function, used in code that calls LimitDerivativeTimes() (and
157  required in code that calls Optimize()), returns the largest time
158  't' in any of the 'outputs' in the computation request, or crashes if there
159  are no outputs (or no cindexes in those outputs). */
161 
162 
/// This is the top-level function for optimizing a computation.
///
/// @param config  Selects which optimizations are applied.  See
///                NnetOptimizeOptions::Register() for the meaning of each
///                option; for example, --optimize=false turns them all off.
/// @param nnet    The network (presumably the one 'computation' was
///                compiled for).
/// @param max_output_time_in_request
///                The largest 't' value in any output of the computation
///                request, as returned by MaxOutputTimeInRequest().
///                NOTE(review): presumably used to turn
///                config.max_deriv_time_relative into an absolute time;
///                confirm against the implementation.
/// @param computation  The computation to optimize (non-const pointer;
///                presumably modified in place).
185 void Optimize(const NnetOptimizeOptions &config,
186  const Nnet &nnet,
187  int32 max_output_time_in_request,
188  NnetComputation *computation);
189 
190 
191 
195 
197  use_shortcut(true),
198  cache_capacity(64) { }
199 
200  void Register(OptionsItf *opts) {
201  opts->Register("use-shortcut", &use_shortcut,
202  "If true, use the 'shortcut' in compilation whereby "
203  "computation requests with regular structure are identified "
204  "as such, a computation with a smaller number of distinct "
205  "values of 'n' is compiled (e.g. 2), and the compiled "
206  "computation is expanded to match the size of the real "
207  "computation request.");
208  opts->Register("cache-capacity", &cache_capacity,
209  "Determines how many computations the computation-cache will "
210  "store (most-recently-used).");
211  }
212 };
213 
220  public:
221  CachingOptimizingCompiler(const Nnet &nnet,
222  const CachingOptimizingCompilerOptions config =
224 
226  CachingOptimizingCompiler(const Nnet &nnet,
227  const NnetOptimizeOptions &opt_config,
228  const CachingOptimizingCompilerOptions config =
230 
232 
240  std::shared_ptr<const NnetComputation> Compile(
241  const ComputationRequest &request);
242  void ReadCache(std::istream &is, bool binary);
243  void WriteCache(std::ostream &os, bool binary);
244 
245 
246  // GetSimpleNnetContext() is equivalent to calling:
247  // ComputeSimpleNnetContext(nnet_, &nnet_left_context,
248  // &nnet_right_context)
249  // but it caches it inside this class. This functionality is independent of
250  // the rest of the functionality of this class; it just happens to be a
251  // convenient place to put this mechanism.
252  void GetSimpleNnetContext(int32 *nnet_left_context,
253  int32 *nnet_right_context);
254 
255  private:
256 
257  // This function just implements the work of Compile(); it's made a separate
258  // function for the convenience of the timer code, to avoid it being called
259  // twice (we also call this function directly from inside the class).
260  std::shared_ptr<const NnetComputation> CompileInternal(const ComputationRequest &request);
261 
262  // This function, called from CompileInternal(), is called when a
263  // ComputationRequest has been determined not to have already been cached. It
264  // otherwise has the same interface as CompileInternal(), but assumes that
265  // there is nothing cached for this computation as yet. It compiles the
266  // computation and takes care of caching it.
267  std::shared_ptr<const NnetComputation> CompileAndCache(const ComputationRequest &request);
268 
269 
270  // This function, called from CompileInternal(), tries to compile the
271  // ComputationRequest 'request' via 'shortcut' compilation; if this is
272  // possible, it returns a pointer to a newly allocated computation that it has
273  // compiled this way (note: this computation will not yet have been placed in
274  // the computation cache). If this is not possible for some reason
275  // (e.g. shortcut compilation is disabled in the config; or the computation
276  // request was not decomposable because of too few n values or irregular or
277  // unexpected structure), this function returns NULL and you should compile
278  // via CompileNoShortcut.
279  const NnetComputation *CompileViaShortcut(const ComputationRequest &request);
280 
281  // This function, called from CompileInternal(), tries to compile the
282  // ComputationRequest 'request' via the regular (not shortcut) compilation
283  // process; it returns a pointer to a newly allocated computation that it has
284  // compiled this way (note: this computation will not yet have been placed in
285  // the computation cache).
286  const NnetComputation *CompileNoShortcut(const ComputationRequest &request);
287 
288  const Nnet &nnet_;
291 
292 
293  // seconds spent in various phases of compilation-- for diagnostic messages
301 
303 
304  // These following two variables are only used by the function GetSimpleNnetContext().
307 };
308 
309 
/// Restricts the times 't' at which derivatives are computed. The limits come
/// from opts.min_deriv_time, opts.max_deriv_time and
/// opts.max_deriv_time_relative; see the --min-deriv-time, --max-deriv-time and
/// --max-deriv-time-relative help text in NnetOptimizeOptions::Register().
/// Per that help text, this saves time in the backprop phase for recurrent
/// frameworks.
/// NOTE(review): 'request' is presumably used to find the largest output 't'
/// (cf. MaxOutputTimeInRequest()) when max_deriv_time_relative is set;
/// confirm against the implementation.
340 void LimitDerivativeTimes(const Nnet &nnet,
341  const ComputationRequest &request,
342  const NnetOptimizeOptions &opts,
343  NnetComputation *computation);
344 
/// This optimization consolidates the model-update part of backprop commands.
/// The --consolidate-model-update help text cites recurrent architectures as
/// an example. NOTE(review): presumably this is the optimization gated by
/// NnetOptimizeOptions::consolidate_model_update.
349 void ConsolidateModelUpdate(const Nnet &nnet,
350  NnetComputation *computation);
351 
/// Converts addition operations (commands with "Add" in their names) into
/// copy operations (commands with "Copy" in their names) wherever possible.
/// The --extend-matrices help notes that extend-matrices can reduce memory for
/// TDNNs when used together with --convert-addition=true.
/// NOTE(review): presumably gated by NnetOptimizeOptions::convert_addition.
355 void ConvertAdditionToAssignment(const Nnet &nnet,
356  NnetComputation *computation);
357 
358 
361  const Nnet &nnet,
362  NnetComputation *computation);
363 
364 
/// Removes, where possible, commands of type kSetConst.
/// NOTE(review): presumably the optimization that --initialize-undefined
/// ("avoids redundant zeroing") controls; confirm.
368 void RemoveUnnecessaryZeroing(const Nnet &nnet, NnetComputation *computation);
369 
370 
/// Moves commands that allocate and zero matrices as late as possible, and
/// commands that deallocate matrices as early as possible (per the
/// --move-sizing-commands help text, to conserve memory).
/// NOTE(review): presumably gated by NnetOptimizeOptions::move_sizing_commands.
373 void MoveSizingCommands(const Nnet &nnet, NnetComputation *computation);
374 
/// Detects cases where a matrix is deallocated and another matrix is later
/// allocated, so that the memory can be re-used. The --allocate-from-other
/// help text describes the case of a matrix of the same size.
/// NOTE(review): presumably gated by NnetOptimizeOptions::allocate_from_other.
378 void RemoveUnnecessaryAllocation(const Nnet &nnet,
379  NnetComputation *computation);
380 
381 
/// Places the input operations (kAcceptInput) and output operations
/// (kProvideOutput) at particular positions within the computation.
/// NOTE(review): the exact target positions are not stated in this header
/// excerpt; see the implementation before relying on them.
390 void ConsolidateIoOperations(const Nnet &nnet,
391  NnetComputation *computation);
392 
393 
394 
395 } // namespace nnet3
396 } // namespace kaldi
397 
398 
399 #endif
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void Register(OptionsItf *opts)
Definition: nnet-optimize.h:84
void Read(std::istream &is, bool binary)
void ConsolidateIoOperations(const Nnet &nnet, NnetComputation *computation)
This optimization puts the input operations (kAcceptInput) and output operations (kProvideOutput) at ...
void Write(std::ostream &os, bool binary) const
ArpaLmCompiler * Compile(bool seps, const std::string &infile)
This class enables you to do the compilation and optimization in one call, and also ensures that if t...
void VariableMergingOptimization(const NnetOptimizeOptions &config, const Nnet &nnet, NnetComputation *computation)
This wraps class VariableMergingOptimizer in a simplified interface.
This file contains utilities for analyzing and checking computations, which are used in the optimizat...
void ConvertAdditionToAssignment(const Nnet &nnet, NnetComputation *computation)
This converts addition operations (things with Add in their names) to copy operations (things with Co...
kaldi::int32 int32
void LimitDerivativeTimes(const Nnet &nnet, int32 min_deriv_time, int32 max_deriv_time, NnetComputation *computation)
virtual void Register(const std::string &name, bool *ptr, const std::string &doc)=0
void MoveSizingCommands(const Nnet &nnet, NnetComputation *computation)
This optimization moves commands that allocate and zero matrices to as late as possible, and moves commands that deallocate matrices to as early as possible.
bool operator==(const NnetOptimizeOptions &other) const
int32 MaxOutputTimeInRequest(const ComputationRequest &request)
void RemoveUnnecessaryAllocation(const Nnet &nnet, NnetComputation *computation)
This optimization detects cases where we deallocate a matrix, and then later allocate another matrix ...
void RemoveUnnecessaryZeroing(const Nnet &nnet, NnetComputation *computation)
This optimization function removes, where possible, commands of type kSetConst.
void Optimize(const NnetOptimizeOptions &config, const Nnet &nnet, int32 max_output_time_in_request, NnetComputation *computation)
This is the top-level function for optimizing a computation.
void ConsolidateModelUpdate(const Nnet &nnet, NnetComputation *computation)
This optimization consolidates the model-update part of backprop commands, for components in (e...
Class ComputationCache is used inside class CachingOptimizingCompiler to cache previously computed co...
CachingOptimizingCompilerOptions config_