MemoryCompressionOptimizer Class Reference

This class is used in the function OptimizeMemoryCompression(), once we determine that there is some potential to do memory compression for this computation. More...

Collaboration diagram for MemoryCompressionOptimizer:

Classes

struct  MatrixCompressInfo
 

Public Member Functions

 MemoryCompressionOptimizer (const Nnet &nnet, int32 memory_compression_level, int32 middle_command, NnetComputation *computation)
 
void Optimize ()
 

Private Member Functions

void ProcessMatrix (int32 m)
 
void ModifyComputation ()
 

Private Attributes

std::vector< MatrixCompressInfo > compress_info_
 
const Nnet & nnet_
 
int32 memory_compression_level_
 
int32 middle_command_
 
NnetComputation * computation_
 
Analyzer analyzer_
 

Detailed Description

This class is used in the function OptimizeMemoryCompression(), once we determine that there is some potential to do memory compression for this computation.

Definition at line 4690 of file nnet-optimize-utils.cc.

Constructor & Destructor Documentation

◆ MemoryCompressionOptimizer()

MemoryCompressionOptimizer (const Nnet &nnet,
                            int32 memory_compression_level,
                            int32 middle_command,
                            NnetComputation *computation)
inline
Parameters
  [in]      nnet                       The neural net the computation is for.
  [in]      memory_compression_level   The level of compression: 0 = no compression (the constructor should not be called with this value); 1 = compression that doesn't affect the results (but still takes time); 2 = compression that affects the results only very slightly; 3 = compression that affects the results a little more.
  [in]      middle_command             Must be the command-index of the command of type kNoOperationMarker in 'computation'.
  [in,out]  computation                The computation we're optimizing.

Definition at line 4703 of file nnet-optimize-utils.cc.

References kaldi::nnet3::Optimize().
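The following is a minimal usage sketch, not the actual body of kaldi::nnet3::OptimizeMemoryCompression() (whose exact signature and sanity checks may differ). It assumes an Nnet 'nnet', an int32 'memory_compression_level' and an NnetComputation pointer 'computation' are in scope: it locates the kNoOperationMarker command that separates the forward and backward passes, then constructs the optimizer and runs it.

  // Sketch only: the real driver performs additional checks before doing this.
  int32 middle_command = -1;
  for (size_t c = 0; c < computation->commands.size(); c++) {
    if (computation->commands[c].command_type == kNoOperationMarker) {
      middle_command = static_cast<int32>(c);
      break;
    }
  }
  if (middle_command >= 0 && memory_compression_level > 0) {
    MemoryCompressionOptimizer opt(nnet, memory_compression_level,
                                   middle_command, computation);
    opt.Optimize();
  }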

Member Function Documentation

◆ ModifyComputation()

void ModifyComputation ( )
private

Definition at line 4769 of file nnet-optimize-utils.cc.

References MemoryCompressionOptimizer::MatrixCompressInfo::compression_command_index, MemoryCompressionOptimizer::MatrixCompressInfo::compression_type, MemoryCompressionOptimizer::computation_, NnetComputation::GetWholeSubmatrices(), kaldi::nnet3::InsertCommands(), kaldi::nnet3::kCompressMatrix, kaldi::nnet3::kDecompressMatrix, MemoryCompressionOptimizer::MatrixCompressInfo::m, MemoryCompressionOptimizer::MatrixCompressInfo::range, MemoryCompressionOptimizer::MatrixCompressInfo::truncate, and MemoryCompressionOptimizer::MatrixCompressInfo::uncompression_command_index.

{
  // whole_submatrices[m] is the submatrix-index of the submatrix that
  // represents the whole of matrix m.
  std::vector<int32> whole_submatrices;
  computation_->GetWholeSubmatrices(&whole_submatrices);

  // 'pairs_to_insert' will be a list of pairs (command-index, command),
  // meaning: (command-index just before which to insert this command; command
  // to insert).
  std::vector<std::pair<int32, NnetComputation::Command> >
      pairs_to_insert;
  pairs_to_insert.reserve(compress_info_.size() * 2);
  for (size_t i = 0; i < compress_info_.size(); i++) {
    const MatrixCompressInfo &info = compress_info_[i];
    int32 s = whole_submatrices[info.m];
    // below we use compression_command_index + 1 because we want the
    // compression to go after the command in 'info.compression_command_index'
    // (which might be, for instance, a forward propagation command).
    std::pair<int32, NnetComputation::Command> p1(
        info.compression_command_index + 1,
        NnetComputation::Command(info.range, kCompressMatrix,
                                 s, static_cast<int32>(info.compression_type),
                                 info.truncate ? 1 : 0));
    pairs_to_insert.push_back(p1);
    std::pair<int32, NnetComputation::Command> p2(
        info.uncompression_command_index,
        NnetComputation::Command(1.0, kDecompressMatrix, s));
    pairs_to_insert.push_back(p2);
  }
  InsertCommands(&pairs_to_insert,
                 computation_);
}
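The pairs in 'pairs_to_insert' give command-indices into the original command list, and kaldi::nnet3::InsertCommands() merges the new commands in just before those positions. Below is a toy, self-contained sketch of that insertion convention only; it is not the Kaldi implementation, and it uses plain ints in place of commands.

  #include <algorithm>
  #include <cstdio>
  #include <utility>
  #include <vector>

  // Insert each (index, value) pair just before position 'index' of the
  // *original* vector; pairs whose index equals orig.size() go at the end.
  std::vector<int> InsertAtOriginalIndices(
      const std::vector<int> &orig,
      std::vector<std::pair<int, int> > *new_items) {
    // Order by insertion point, keeping the relative order of items that
    // share an insertion point.
    std::stable_sort(new_items->begin(), new_items->end(),
                     [](const std::pair<int, int> &a,
                        const std::pair<int, int> &b) {
                       return a.first < b.first;
                     });
    std::vector<int> result;
    result.reserve(orig.size() + new_items->size());
    size_t n = 0;
    for (size_t i = 0; i < orig.size(); i++) {
      while (n < new_items->size() &&
             (*new_items)[n].first == static_cast<int>(i))
        result.push_back((*new_items)[n++].second);
      result.push_back(orig[i]);
    }
    while (n < new_items->size())
      result.push_back((*new_items)[n++].second);
    return result;
  }

  int main() {
    std::vector<int> commands = {10, 20, 30};
    std::vector<std::pair<int, int> > to_insert;
    to_insert.push_back(std::make_pair(1, 99));  // e.g. compress: before old index 1.
    to_insert.push_back(std::make_pair(3, 77));  // e.g. decompress: appended at the end.
    std::vector<int> out = InsertAtOriginalIndices(commands, &to_insert);
    for (size_t i = 0; i < out.size(); i++)
      std::printf("%d ", out[i]);  // prints: 10 99 20 30 77
    std::printf("\n");
    return 0;
  }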

◆ Optimize()

void Optimize ( )

Definition at line 4803 of file nnet-optimize-utils.cc.

References MemoryCompressionOptimizer::computation_, NnetComputation::matrices, and MemoryCompressionOptimizer::nnet_.

Referenced by kaldi::nnet3::OptimizeMemoryCompression().

{
  analyzer_.Init(nnet_, *computation_);
  // note: matrix zero is not really a matrix.
  int32 num_matrices = computation_->matrices.size();
  for (int32 m = 1; m < num_matrices; m++)
    ProcessMatrix(m);
  if (!compress_info_.empty())
    ModifyComputation();
}

◆ ProcessMatrix()

void ProcessMatrix ( int32  m)
private

Definition at line 4813 of file nnet-optimize-utils.cc.

References Access::access_type, NnetComputation::Command::arg1, Access::command_index, NnetComputation::Command::command_type, NnetComputation::commands, MemoryCompressionOptimizer::computation_, Nnet::GetComponent(), KALDI_ASSERT, kaldi::nnet3::kBackprop, kaldi::kCompressedMatrixInt16, kaldi::kCompressedMatrixUint8, kaldi::nnet3::kReadAccess, MemoryCompressionOptimizer::nnet_, and Component::Type().

{
  if (analyzer_.matrix_accesses[m].is_output) {
    return;  // We can't do this optimization for matrices that are going to be
             // output to the user.
  }

  // 'accesses' lists the commands that access this matrix.
  const std::vector<Access> &accesses = analyzer_.matrix_accesses[m].accesses;
  // the 'kReadAccess' below is actually a don't-care.  This is just
  // to find the position in 'accesses' that corresponds to command-index
  // 'middle_command'.
  Access middle_access(middle_command_, kReadAccess);
  std::vector<Access>::const_iterator iter = std::lower_bound(accesses.begin(),
                                                              accesses.end(),
                                                              middle_access);
  // At this point, 'iter' points to the first access in 'accesses'
  // whose command index is >= 'middle_command_' (which separates the forward
  // and backward passes), or accesses.end() if this matrix was not
  // accessed during the backward pass.
  if (iter == accesses.end()) {
    return;  // There is nothing to do: this matrix was not accessed during the
             // backward pass.
  }
  if (iter == accesses.begin()) {
    return;  // There is nothing to do: this matrix was not accessed during the
             // forward pass.
  }
  // 'backward_access' is the first access of the matrix in the backward
  // pass of the computation, and
  // 'forward_access' is the last access of the matrix in the forward pass
  // of the computation.
  const Access &backward_access = iter[0],
      &forward_access = iter[-1];
  KALDI_ASSERT(forward_access.command_index < middle_command_ &&
               backward_access.command_index > middle_command_);

  // 'backward_access_is_last_access' is going to be set to true if
  // 'backward_access' is the last command to access the matrix (apart from
  // deallocation or matrix-swap commands, which don't show up in the list of
  // accesses).
  bool backward_access_is_last_access = (accesses.end() == iter + 1);

  int32 backward_command_index = backward_access.command_index,
      forward_command_index = forward_access.command_index;
  NnetComputation::Command
      &backward_command = computation_->commands[backward_command_index];

  if (memory_compression_level_ >= 1 &&
      backward_access_is_last_access &&
      backward_access.access_type == kReadAccess &&
      backward_command.command_type == kBackprop) {
    int32 component_index = backward_command.arg1;
    const Component *component = nnet_.GetComponent(component_index);
    // this is potentially a candidate for our optimization for ReLU units,
    // where we only need to store the sign.
    if (component->Type() == "RectifiedLinearComponent") {
      compress_info_.push_back(
          MatrixCompressInfo(m, forward_command_index,
                             backward_command_index,
                             kCompressedMatrixUint8, 0.0,
                             true));
      return;
    }
  }

  // If memory_compression_level >= 2 (an "intermediate" level of compression),
  // then we'll consider compressing quantities using 16 bits in the range
  // [-10, 10].  Because of the way this compression works, exact zero will
  // still be uncompressed as exact zero, so even if this is the output
  // of a ReLU, it's OK.  (Having a few derivatives zero for ReLU outputs
  // that were very close to zero is OK.)
  if (memory_compression_level_ >= 2) {
    compress_info_.push_back(
        MatrixCompressInfo(m, forward_command_index,
                           backward_command_index,
                           kCompressedMatrixInt16, 10.0,
                           true));
    return;
  }

  // TODO: later maybe implement something for memory compression level = 3.
}
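A rough sketch of the idea behind the level-2 compression described in the comment above, assuming a simple linear int16 quantization over [-range, range]; the actual CuCompressedMatrix code may differ in its details. Values are scaled, rounded, and clamped ('truncate'), so exact zero round-trips to exact zero while other values lose only a little precision. The helper names here are hypothetical and not part of the Kaldi API.

  #include <algorithm>
  #include <cmath>
  #include <cstdint>

  // Illustration only: linear quantization to int16 over [-range, range].
  int16_t CompressToInt16(float x, float range) {       // e.g. range = 10.0
    float scaled = x * (32767.0f / range);
    scaled = std::min(32767.0f, std::max(-32767.0f, scaled));  // 'truncate' out-of-range values.
    return static_cast<int16_t>(std::lround(scaled));
  }

  float DecompressFromInt16(int16_t q, float range) {
    return q * (range / 32767.0f);
  }
  // CompressToInt16(0.0f, 10.0f) == 0, so zero decompresses to exactly 0.0f,
  // which is why this scheme is safe even for the outputs of a ReLU.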

Member Data Documentation

◆ analyzer_

Analyzer analyzer_
private

Definition at line 4765 of file nnet-optimize-utils.cc.

◆ compress_info_

std::vector<MatrixCompressInfo> compress_info_
private

Definition at line 4759 of file nnet-optimize-utils.cc.

◆ computation_

NnetComputation* computation_
private

Definition at line 4764 of file nnet-optimize-utils.cc.

◆ memory_compression_level_

int32 memory_compression_level_
private

Definition at line 4762 of file nnet-optimize-utils.cc.

◆ middle_command_

int32 middle_command_
private

Definition at line 4763 of file nnet-optimize-utils.cc.

◆ nnet_

const Nnet& nnet_
private

Definition at line 4761 of file nnet-optimize-utils.cc.


The documentation for this class was generated from the following file:
nnet-optimize-utils.cc