Optimization in the "nnet3" setup

Introduction

This page covers the code-optimization process in the "nnet3" setup, in which we modify the sequence of commands stored in the NnetComputation object in order to make the execution more efficient.

Overview of optimization

Optimization happens after compilation. It consists of modifying the NnetComputation to make it more efficient. From the user's point of view it is a single function call:

void Optimize(const NnetOptimizeConfig &config,
              const Nnet &nnet,
              const ComputationRequest &request,
              NnetComputation *computation);

Internally, this performs several types of optimization, which we go through below. We also discuss code analysis, which the optimization code uses to work out which changes to the code are permissible, and code checking.

Code analysis

As mentioned, nnet-analyze.h contains various utilities for analyzing code. For convenience, we define a struct Analyzer which runs all of this analysis on some compiled code:

struct Analyzer {
  ComputationVariables variables;
  std::vector<CommandAttributes> command_attributes;
  std::vector<std::vector<Access> > variable_accesses;
  std::vector<MatrixAccesses> matrix_accesses;
  void Init(const Nnet &nnet, const NnetComputation &computation);
};

The Init function sets up its members:

void Analyzer::Init(const Nnet &nnet, const NnetComputation &computation) {
  variables.Init(computation);
  ComputeCommandAttributes(nnet, computation, variables, &command_attributes);
  ComputeVariableAccesses(variables, command_attributes, &variable_accesses);
  ComputeMatrixAccesses(nnet, computation, variables, command_attributes,
                        &matrix_accesses);
}

There are four function calls here: variables.Init() sets up the member of type ComputationVariables, and the other three calls take it as an argument and fill in the remaining members.

Computation variables

In order to analyze the computation, it is helpful to break it down into actions on individual variables. If efficiency were not an issue, we could do the analysis at the level of individual elements of a matrix, declaring these to be the variables. However, for now we do the analysis at a slightly coarser-grained level, consisting of column ranges of matrices. We choose the coarsest set of column ranges such that the column range of every sub-matrix can be expressed exactly as a union of these column ranges. Class ComputationVariables is responsible for identifying these column ranges and for getting the set of variables associated with a given matrix or sub-matrix.
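
The choice of column ranges can be illustrated with a minimal sketch for a single matrix. This is not the code of class ComputationVariables; the struct ColumnRange and the functions ComputeSplitPoints() and VariablesForRange() are hypothetical names introduced only for this illustration. The idea is that every sub-matrix boundary becomes a split point, and each pair of consecutive split points delimits one variable.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <set>
#include <vector>

// Hypothetical, simplified description of a sub-matrix's column range:
// columns [col_offset, col_offset + num_cols) of its parent matrix.
struct ColumnRange {
  int32_t col_offset;
  int32_t num_cols;
};

// Returns the sorted split points for one matrix with 'matrix_num_cols'
// columns.  Consecutive split points p[i], p[i+1] delimit one variable
// covering columns [p[i], p[i+1]).  Only sub-matrix boundaries (and the
// matrix's own boundaries) become split points, so the ranges are as coarse
// as possible.
std::vector<int32_t> ComputeSplitPoints(
    int32_t matrix_num_cols,
    const std::vector<ColumnRange> &submatrix_ranges) {
  std::set<int32_t> points;
  points.insert(0);
  points.insert(matrix_num_cols);
  for (const ColumnRange &r : submatrix_ranges) {
    assert(r.col_offset >= 0 && r.col_offset + r.num_cols <= matrix_num_cols);
    points.insert(r.col_offset);
    points.insert(r.col_offset + r.num_cols);
  }
  return std::vector<int32_t>(points.begin(), points.end());
}

// Returns the indexes (local to this matrix) of the variables that make up
// the column range 'r'.
std::vector<int32_t> VariablesForRange(
    const std::vector<int32_t> &split_points, const ColumnRange &r) {
  std::vector<int32_t> ans;
  for (std::size_t i = 0; i + 1 < split_points.size(); i++) {
    if (split_points[i] >= r.col_offset &&
        split_points[i + 1] <= r.col_offset + r.num_cols)
      ans.push_back(static_cast<int32_t>(i));
  }
  return ans;
}
```

For a 10-column matrix with sub-matrices covering columns 0 to 9, 0 to 3 and 2 to 7, the split points are 0, 2, 4, 8 and 10, giving four variables; the third sub-matrix corresponds to the second and third of them.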

In the future we may decide to do the analysis at a finer level (e.g. individual rows of the current variables), which would enable more complete optimization in certain rather specialized circumstances. This would not involve very extensive changes to the code outside of class ComputationVariables.

A "variable" is a zero-based index that corresponds to one of the variables identified by class ComputationVariables. The public interface of this class is below:

class ComputationVariables {
 public:
  void Init(const NnetComputation &computation);

  // This function updates the CommandAttributes object to record an access of
  // type read, write or read-write on the variables that this sub-matrix
  // corresponds to, and also updates the matrices_read and/or matrices_written
  // lists by adding the index of the underlying matrix.
  void RecordAccessForSubmatrix(
      int32 submatrix_index,
      AccessType access_type,
      CommandAttributes *ca) const;
  // Appends to variable_indexes the list of variables corresponding to a
  // matrix index.
  void AppendVariablesForMatrix(
      int32 matrix_index,
      std::vector<int32> *variable_indexes) const;
  int32 NumVariables() const { return num_variables_; }
  int32 GetMatrixForVariable(int32 variable) const;

 private:
   ...
};

The RecordAccessForSubmatrix() function won't be very self-explanatory because we haven't yet introduced struct CommandAttributes. We'll say more about it below.

Command attributes

Struct CommandAttributes records, for a single command, which variables are read and which are written, and also which matrices are read and written.

struct CommandAttributes {
  // variables read
  std::vector<int32> variables_read;
  // variables written
  std::vector<int32> variables_written;

  // matrices read
  std::vector<int32> matrices_read;
  // matrices written
  std::vector<int32> matrices_written;

  // true if this command has side effects e.g. on the model (such as
  // Backprop on an updatable component, or StoreStats).
  bool has_side_effects;
  CommandAttributes(): has_side_effects(false) { }
};

Some operations must be considered read/write instead of just read or write. For instance, adding something to a matrix is a read/write operation because the final result depends on what was there previously. In these cases we add a variable (or matrix) to both read and written lists. In addition, a pure-write operation that accesses only some parts of a variable or matrix must be considered a read/write operation on that variable or matrix, because the final value still depends on the contents at the start.

The function ComputationVariables::RecordAccessForSubmatrix() is responsible for updating the CommandAttributes variable for commands; it is declared as follows.

  void RecordAccessForSubmatrix(
      int32 submatrix_index,
      AccessType access_type,
      CommandAttributes *ca) const;

where AccessType is an enum that can take values kReadAccess, kWriteAccess and kReadWriteAccess.
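
The rule that a partial write counts as a read/write can be sketched as follows. This is a simplified illustration and not the body of RecordAccessForSubmatrix(): the struct SimpleAttributes, the function RecordAccess() and its covers_whole_variable argument are assumptions made for the example, and matrices are left out.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

enum AccessType { kReadAccess, kWriteAccess, kReadWriteAccess };

// Simplified stand-in for CommandAttributes (variables only).
struct SimpleAttributes {
  std::vector<int32_t> variables_read;
  std::vector<int32_t> variables_written;
};

// Hypothetical helper: records one access to 'variable'.
// 'covers_whole_variable' is false when the command touches only part of the
// variable, e.g. only some of its rows.
void RecordAccess(int32_t variable, AccessType access_type,
                  bool covers_whole_variable, SimpleAttributes *attr) {
  assert(attr != nullptr);
  // A write that leaves part of the variable untouched still depends on the
  // previous contents, so it is upgraded to a read/write.
  if (access_type == kWriteAccess && !covers_whole_variable)
    access_type = kReadWriteAccess;
  if (access_type != kWriteAccess)
    attr->variables_read.push_back(variable);
  if (access_type != kReadAccess)
    attr->variables_written.push_back(variable);
}
```

With this helper, a full write appears only in variables_written, while a partial write appears in both lists, exactly as an explicit read/write access would.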

Computing the command attributes

After initializing the ComputationVariables object, the next stage in the analysis of a computation is to obtain a vector of CommandAttributes, one for each command in the computation. The function ComputeCommandAttributes() is responsible for this. This function is mostly a big switch statement, and we show the first part of it in order to give the reader some idea of what is going on:

void ComputeCommandAttributes(
    const Nnet &nnet,
    const NnetComputation &computation,
    const ComputationVariables &vars,
    std::vector<CommandAttributes> *attributes) {
  int32 num_commands = computation.commands.size();
  attributes->clear();
  attributes->resize(num_commands);
  for (int32 command_index = 0; command_index < num_commands; command_index++) {
    const NnetComputation::Command &c = computation.commands[command_index];
    CommandAttributes &attr = (*attributes)[command_index];
    switch (c.command_type) {
      case NnetComputation::kAllocMatrixZeroed:
        vars.AppendVariablesForMatrix(c.arg1, &attr.variables_written);
        break;
      case NnetComputation::kAllocMatrixUndefined: // nothing is written here.
        break;
      case NnetComputation::kDeallocMatrix: // ditto.
        break;
      case NnetComputation::kPropagate:
        vars.RecordAccessForSubmatrix(c.arg3, kReadAccess, &attr);
        if (nnet.GetComponent(c.arg1)->Properties() & kPropagateAdds)
          vars.RecordAccessForSubmatrix(c.arg4, kReadWriteAccess, &attr);
        else
          vars.RecordAccessForSubmatrix(c.arg4, kWriteAccess, &attr);
        break;
        ...

Computing the variable accesses

The next stage in analysis is to compute the variable accesses. This takes the information we stored in the CommandAttributes and lists it per variable. We define a struct Access as:

struct Access {
  int32 command_index;
  AccessType access_type;
};

where AccessType is the enumeration mentioned above. The accesses to any variable will be stored as a std::vector<Access>, and the function that computes the accesses for all variables is declared as follows:

void ComputeVariableAccesses(
    const ComputationVariables &variables,
    const std::vector<CommandAttributes> &command_attributes,
    std::vector<std::vector<Access> > *variable_accesses);

The output variable_accesses is a vector with one element per variable; each element is a list of accesses sorted by command index. There will be only one access per command, as we consolidate them: for example, a read access and a write access by the same command would be consolidated into a single kReadWriteAccess access.
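
The consolidation described above can be sketched as follows. This is an illustrative reimplementation under simplifying assumptions, not the actual ComputeVariableAccesses(): SimpleAttributes stands in for CommandAttributes, and AddAccess() and SketchVariableAccesses() are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

enum AccessType { kReadAccess, kWriteAccess, kReadWriteAccess };

struct Access {
  int32_t command_index;
  AccessType access_type;
};

// Simplified stand-in for CommandAttributes (variables only).
struct SimpleAttributes {
  std::vector<int32_t> variables_read;
  std::vector<int32_t> variables_written;
};

// Appends an access to 'accesses', merging it into the last entry if that
// entry belongs to the same command.  Merging two different access types
// always yields a read/write access.
static void AddAccess(int32_t command_index, AccessType type,
                      std::vector<Access> *accesses) {
  if (!accesses->empty() && accesses->back().command_index == command_index) {
    if (accesses->back().access_type != type)
      accesses->back().access_type = kReadWriteAccess;
  } else {
    accesses->push_back(Access{command_index, type});
  }
}

// Returns one access list per variable.  Commands are visited in order, so
// each list comes out sorted by command index.
std::vector<std::vector<Access> > SketchVariableAccesses(
    int32_t num_variables,
    const std::vector<SimpleAttributes> &command_attributes) {
  std::vector<std::vector<Access> > ans(num_variables);
  for (std::size_t c = 0; c < command_attributes.size(); c++) {
    const SimpleAttributes &attr = command_attributes[c];
    const int32_t command_index = static_cast<int32_t>(c);
    for (int32_t v : attr.variables_read) {
      assert(v >= 0 && v < num_variables);
      AddAccess(command_index, kReadAccess, &ans[v]);
    }
    for (int32_t v : attr.variables_written) {
      assert(v >= 0 && v < num_variables);
      AddAccess(command_index, kWriteAccess, &ans[v]);
    }
  }
  return ans;
}
```

Because each command's reads are processed before its writes, a variable that appears in both lists of the same command ends up with a single kReadWriteAccess entry for that command.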

Computing the matrix accesses

Struct MatrixAccesses stores all the information we record for a single matrix, relating to how it is allocated and accessed:

struct MatrixAccesses {
  // Index of the command that allocates the matrix, or -1 if the command
  // doesn't exist (e.g. it is an input).
  int32 allocate_command;
  // Index of the command that deallocates the matrix, or -1 if it never gets
  // deallocated (e.g. it is an output).
  int32 deallocate_command;
  // Records the indexes of commands that access the matrix, and the type
  // (read, read/write, write).  It will be sorted on command index with only
  // one record per command.  Note: a write to only a part of the matrix
  // (i.e. a submatrix that isn't the whole thing) will be recorded as an
  // access of type read/write.
  std::vector<Access> accesses;
  // true if this matrix is an input to the computation.
  bool is_input;
  // true if this matrix is an output of the computation.
  bool is_output;
  MatrixAccesses(): allocate_command(-1), deallocate_command(-1),
                    is_input(false), is_output(false) { }
};

You can see that we store more information than we do for the variables (i.e. more than just the std::vector<Access>). This is so that we can check whether the matrix is being allocated and deallocated appropriately.

The matrix accesses are computed by the function ComputeMatrixAccesses().

Checking the computation

After performing the analysis as described in the previous section, we can check the computation using class ComputationChecker. We list some of its code below; a glance at some of the names of private function members will indicate the kind of checks it is performing.

class ComputationChecker {
 public:
  ComputationChecker(const CheckComputationConfig &config,
                     const Nnet &nnet,
                     const ComputationRequest &request,
                     const NnetComputation &computation);
  void Check();
 private:
  // various dimension consistency checks and checks on properties.
  void CheckComputationIndexes() const;
  // make sure Propagate comes before kNoOpMarker and Backprop comes after it,
  // and that the value of forward_computation_end matches the position of
  // kNoOpMarker.
  void CheckComputationOrder() const;
  // checks for a situation where an undefined variable is read.
  void CheckComputationUndefined() const;
  // checks that all writes are done before reads.  details with implementation.
  void CheckComputationRewrite() const;
  // check matrix accesses make sense.
  void CheckComputationMatrixAccesses() const;
  ...

This checking code exists mainly to detect bugs in the compilation and optimization code.
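
To give a flavor of how such a check can use the analysis results, here is a minimal sketch of an undefined-read check over per-variable access lists. It is not the implementation of CheckComputationUndefined(); the function FirstUndefinedRead() is a hypothetical name, and the sketch ignores the special case of computation inputs, which are defined without any write command.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

enum AccessType { kReadAccess, kWriteAccess, kReadWriteAccess };

struct Access {
  int32_t command_index;
  AccessType access_type;
};

// Returns the index of the first variable whose earliest access is a read or
// a read/write (i.e. it is read before anything has been written to it), or
// -1 if there is no such variable.  Each access list is assumed to be sorted
// by command index.
int32_t FirstUndefinedRead(
    const std::vector<std::vector<Access> > &variable_accesses) {
  for (std::size_t v = 0; v < variable_accesses.size(); v++) {
    const std::vector<Access> &accesses = variable_accesses[v];
    if (accesses.empty())
      continue;  // never accessed, so never read while undefined.
    assert(accesses[0].command_index >= 0);
    if (accesses[0].access_type != kWriteAccess)
      return static_cast<int32_t>(v);
  }
  return -1;
}
```

A checker built along these lines would report an error when the function returns a non-negative index, since that variable's contents are used before they are defined.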

Optimization

The optimization code has an options class that lets the user turn off specific optimizations. This is intended to help in debugging. The options class has a variable for each individual optimization:

struct NnetOptimizeConfig {
  bool optimize;  // setting this to false disallows all optimization.
  bool propagate_in_place;
  bool backprop_in_place;
  bool remove_assignments;
  bool initialize_undefined;
  bool move_sizing_commands;
  ...
};

The top-level call to the optimization code is just a function call. We show some partial code for this function below:

void Optimize(const NnetOptimizeConfig &config,
              const Nnet &nnet,
              const ComputationRequest &request,
              NnetComputation *computation) {
  if (!config.optimize)
    return;
  bool changed = true;
  while (changed) {
    changed = false;
    VariableMergingOptimizer opt(config, nnet, request, computation);
    if (opt.MergeVariables())
      changed = true;
  }
  if (config.initialize_undefined)
    RemoveUnnecessaryZeroing(nnet, computation);

  if (config.move_sizing_commands)
    MoveSizingCommands(nnet, computation);
}

The VariableMergingOptimizer is a class that is responsible for merging variables together; it detects situations where there are two separate matrices that can be replaced with a single matrix.