ComputationCache Class Reference

Class ComputationCache is used inside class CachingOptimizingCompiler to cache previously computed computations.

#include <nnet-optimize-utils.h>


Public Member Functions

 ComputationCache (int32 cache_capacity)
 
void Read (std::istream &is, bool binary)
 
void Write (std::ostream &os, bool binary) const
 
std::shared_ptr< const NnetComputation > Find (const ComputationRequest &request)
 
std::shared_ptr< const NnetComputation > Insert (const ComputationRequest &request, const NnetComputation *computation)
 
 ~ComputationCache ()
 
void Check (const Nnet &nnet) const
 

Private Types

typedef std::list< const ComputationRequest * > AqType
 
typedef unordered_map< const ComputationRequest *, std::pair< std::shared_ptr< const NnetComputation >, AqType::iterator >, ComputationRequestHasher, ComputationRequestPtrEqual > CacheType
 

Private Attributes

std::mutex mutex_
 
int32 cache_capacity_
 
AqType access_queue_
 
CacheType computation_cache_
 

Detailed Description

Class ComputationCache is used inside class CachingOptimizingCompiler to cache previously computed computations.

The code was moved from class CachingOptimizingCompiler into this separate class for clarity when thread-safety support was added. Find() and Insert() each lock mutex_, so they may be called from multiple threads without additional synchronization. Read(), Write() and Check() do not take the lock.

Definition at line 625 of file nnet-optimize-utils.h.
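The members listed above form a standard least-recently-used (LRU) cache: a std::list (AqType) records access order, and an unordered_map (CacheType) maps each request to a pair of the shared computation and that request's position in the list, so a hit can be moved and an eviction can be performed in O(1). The sketch below reproduces that pattern under simplifying assumptions. The class name LruCacheSketch is hypothetical, std::string keys stand in for ComputationRequest pointers, and it is not Kaldi's implementation.

```cpp
#include <cassert>
#include <list>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical LRU cache sketch modeled on ComputationCache. std::string
// stands in for ComputationRequest and V for NnetComputation.
template <typename V>
class LruCacheSketch {
 public:
  explicit LruCacheSketch(int capacity) : capacity_(capacity) {
    assert(capacity > 0);  // same precondition as the real constructor
  }

  // Returns nullptr on a miss. On a hit, moves the key to the back of the
  // access queue, marking it as the most recently used.
  std::shared_ptr<const V> Find(const std::string &key) {
    std::lock_guard<std::mutex> lock(mutex_);
    typename CacheType::iterator iter = cache_.find(key);
    if (iter == cache_.end()) return nullptr;
    queue_.splice(queue_.end(), queue_, iter->second.second);
    return iter->second.first;
  }

  // Takes ownership of `value`. If the cache is full, first evicts the least
  // recently used key (the front of the queue). If `key` is already cached,
  // the existing entry wins and `value` is returned without being cached.
  std::shared_ptr<const V> Insert(const std::string &key, const V *value) {
    std::lock_guard<std::mutex> lock(mutex_);
    std::shared_ptr<const V> ptr(value);
    if (static_cast<int>(cache_.size()) >= capacity_) {
      cache_.erase(queue_.front());
      queue_.pop_front();
    }
    typename AqType::iterator ait = queue_.insert(queue_.end(), key);
    if (!cache_.emplace(key, std::make_pair(ptr, ait)).second)
      queue_.erase(ait);  // duplicate key: drop the queue node just added
    return ptr;
  }

  int Size() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return static_cast<int>(cache_.size());
  }

 private:
  using AqType = std::list<std::string>;  // front = least recently used
  using CacheType = std::unordered_map<
      std::string,
      std::pair<std::shared_ptr<const V>, typename AqType::iterator>>;

  mutable std::mutex mutex_;
  int capacity_;
  AqType queue_;
  CacheType cache_;
};
```

With a capacity of 2, inserting "a" and "b", then calling Find("a"), makes "b" the least recently used entry, so a following Insert("c") evicts "b" and keeps "a".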

Member Typedef Documentation

◆ AqType

typedef std::list<const ComputationRequest*> AqType
private

Definition at line 668 of file nnet-optimize-utils.h.

◆ CacheType

typedef unordered_map<const ComputationRequest*, std::pair<std::shared_ptr<const NnetComputation>, AqType::iterator>, ComputationRequestHasher, ComputationRequestPtrEqual> CacheType
private

Definition at line 677 of file nnet-optimize-utils.h.
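CacheType is keyed by `const ComputationRequest *`, yet Find() looks up `&in_request`, the address of the caller's object, and expects to hit the heap copy the cache allocated in Insert(). That presumably works because ComputationRequestHasher and ComputationRequestPtrEqual hash and compare the pointed-to requests rather than the addresses. The sketch below shows that idea with a hypothetical Request struct and hypothetical RequestPtrHasher/RequestPtrEqual functors; it is an assumption about their role, not their actual code.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for ComputationRequest.
struct Request {
  std::string description;
};

// Hash and compare the pointed-to Request rather than the pointer value,
// the role assumed here for ComputationRequestHasher and
// ComputationRequestPtrEqual.
struct RequestPtrHasher {
  std::size_t operator()(const Request *r) const {
    return std::hash<std::string>()(r->description);
  }
};
struct RequestPtrEqual {
  bool operator()(const Request *a, const Request *b) const {
    return a->description == b->description;
  }
};

using PtrKeyedMap =
    std::unordered_map<const Request *, int, RequestPtrHasher, RequestPtrEqual>;

// Stores a heap-allocated key, then finds it through a pointer to a
// different object with equal contents, as Find(&in_request) does.
int LookUpByValue() {
  PtrKeyedMap map;
  const Request *owned = new Request{"output-frames-0-9"};  // map owns this
  map[owned] = 42;
  Request probe{"output-frames-0-9"};  // distinct address, same contents
  PtrKeyedMap::const_iterator it = map.find(&probe);
  assert(it == map.end() || it->first == owned);  // the stored key is found
  int result = (it == map.end()) ? -1 : it->second;
  for (const auto &kv : map) delete kv.first;  // like ~ComputationCache()
  return result;
}
```

Because the map owns the key pointers, they must be deleted explicitly when entries leave the map, which is what the destructor and the eviction branch of Insert() do.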

Constructor & Destructor Documentation

◆ ComputationCache()

ComputationCache ( int32  cache_capacity)

Definition at line 4969 of file nnet-optimize-utils.cc.

References KALDI_ASSERT.

ComputationCache::ComputationCache(int32 cache_capacity):
    cache_capacity_(cache_capacity) {
  KALDI_ASSERT(cache_capacity > 0);
}

◆ ~ComputationCache()

~ComputationCache ( )

Definition at line 5062 of file nnet-optimize-utils.cc.

References ComputationCache::computation_cache_.

ComputationCache::~ComputationCache() {
  CacheType::const_iterator iter = computation_cache_.begin(),
      end = computation_cache_.end();
  // We only need to explicitly delete the pointer to the ComputationRequest.
  // The NnetComputation objects are deleted automatically by std::shared_ptr
  // when the reference count goes to zero.
  for (; iter != end; ++iter)
    delete iter->first;
}
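The destructor comment relies on shared ownership: a caller that obtained a computation from Find() or Insert() keeps it alive even after the cache drops its own reference, whether through eviction or destruction. A minimal sketch of that behavior, with a hypothetical FakeComputation struct standing in for NnetComputation and a string key standing in for the request:

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for NnetComputation.
struct FakeComputation {
  int id;
};

// Shows that erasing the cache's copy of a shared_ptr (as eviction in
// Insert() does, or destroying the map in ~ComputationCache()) leaves a
// caller's copy valid. Returns the caller's use_count after the erase.
long UseCountAfterEviction() {
  std::unordered_map<std::string, std::shared_ptr<const FakeComputation>> cache;
  cache["req"] = std::make_shared<FakeComputation>(FakeComputation{7});
  std::shared_ptr<const FakeComputation> held = cache["req"];  // as from Find()
  assert(held.use_count() == 2);  // owned by the cache and the caller
  cache.erase("req");             // the cache drops its reference
  assert(held->id == 7);          // the object is still alive
  return held.use_count();        // only the caller's reference remains
}
```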

Member Function Documentation

◆ Check()

void Check ( const Nnet & nnet) const

Definition at line 5037 of file nnet-optimize-utils.cc.

References ComputationChecker::Check(), and ComputationCache::computation_cache_.

Referenced by CachingOptimizingCompiler::ReadCache().

void ComputationCache::Check(const Nnet &nnet) const {
  CacheType::const_iterator iter = computation_cache_.begin(),
      end = computation_cache_.end();
  // Run ComputationChecker on every cached computation. This only reads the
  // cache; nothing is modified or deleted.
  for (; iter != end; ++iter) {
    const NnetComputation &computation = *(iter->second.first);
    CheckComputationOptions check_config;
    ComputationChecker checker(check_config, nnet, computation);
    checker.Check();
  }
}

◆ Find()

std::shared_ptr< const NnetComputation > Find ( const ComputationRequest & request)

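Definition at line 4951 of file nnet-optimize-utils.cc.

References ComputationCache::access_queue_, ComputationCache::computation_cache_, and ComputationCache::mutex_.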

Referenced by CachingOptimizingCompiler::CompileInternal().

std::shared_ptr<const NnetComputation> ComputationCache::Find(
    const ComputationRequest &in_request) {
  std::lock_guard<std::mutex> lock(mutex_);

  CacheType::iterator iter = computation_cache_.find(&in_request);
  if (iter == computation_cache_.end()) {
    return NULL;
  } else {
    std::shared_ptr<const NnetComputation> ans = iter->second.first;
    // Update access record by moving the accessed request to the end of the
    // access queue, which declares that it's the most recently used.
    access_queue_.splice(access_queue_.end(), access_queue_,
                         iter->second.second);
    return ans;
  }
}
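Moving the accessed request to the end of access_queue_ can be done in O(1) with std::list::splice, which relinks the existing node instead of copying it, so the AqType::iterator stored in computation_cache_ stays valid. A small sketch of that operation; MarkMostRecentlyUsed and SpliceDemo are hypothetical names:

```cpp
#include <cassert>
#include <iterator>
#include <list>
#include <string>

// Moves the node at `it` to the back of `queue` in O(1). No element is
// copied, and `it` (like every other iterator into the list) stays valid.
void MarkMostRecentlyUsed(std::list<std::string> &queue,
                          std::list<std::string>::iterator it) {
  queue.splice(queue.end(), queue, it);
}

// Touching "b" makes it the most recently used entry.
std::list<std::string> SpliceDemo() {
  std::list<std::string> queue = {"a", "b", "c"};
  std::list<std::string>::iterator it = std::next(queue.begin());  // "b"
  MarkMostRecentlyUsed(queue, it);
  assert(*it == "b" && &*it == &queue.back());  // iterator still valid
  return queue;
}
```

This iterator stability is why CacheType can store an AqType::iterator next to each computation: the map entry never needs updating when its queue node moves.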

◆ Insert()

std::shared_ptr< const NnetComputation > Insert ( const ComputationRequest & request,
const NnetComputation * computation 
)

Definition at line 4974 of file nnet-optimize-utils.cc.

References ComputationCache::access_queue_, ComputationCache::cache_capacity_, ComputationCache::computation_cache_, KALDI_ASSERT, and ComputationCache::mutex_.

Referenced by CachingOptimizingCompiler::CompileInternal(), and ComputationCache::Read().

std::shared_ptr<const NnetComputation> ComputationCache::Insert(
    const ComputationRequest &request_in,
    const NnetComputation *computation_in) {

  std::lock_guard<std::mutex> lock(mutex_);
  if (static_cast<int32>(computation_cache_.size()) >= cache_capacity_) {
    // Cache has reached capacity; purge the least-recently-accessed request
    const CacheType::iterator iter =
        computation_cache_.find(access_queue_.front());
    KALDI_ASSERT(iter != computation_cache_.end());
    const ComputationRequest *request = iter->first;
    computation_cache_.erase(iter);
    delete request;
    // we don't need to delete the computation in iter->second.first, as the
    // shared_ptr takes care of that automatically.
    access_queue_.pop_front();
  }

  // Now insert the thing we need to insert. We'll own the pointer 'request' in
  // 'computation_cache_', so we need to allocate our own version.
  ComputationRequest *request = new ComputationRequest(request_in);
  // When we construct this shared_ptr, it takes ownership of the pointer
  // 'computation_in'.
  std::shared_ptr<const NnetComputation> computation(computation_in);

  AqType::iterator ait = access_queue_.insert(access_queue_.end(), request);

  std::pair<CacheType::iterator, bool> p = computation_cache_.insert(
      std::make_pair(request, std::make_pair(computation, ait)));
  if (!p.second) {
    // if p.second is false, this pair was not inserted because
    // a computation for the same computation-request already existed in
    // the map. This is possible in multi-threaded operations, if two
    // threads try to compile the same computation at the same time (only
    // one of them will successfully add it).
    // We need to erase the access-queue element that we just added, it's
    // no longer going to be needed.
    access_queue_.erase(ait);
    delete request;
  }
  return computation;
}
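The last part of Insert() handles the race in which two threads compile the same request: the second map insertion fails, so the access-queue node added just before it must be removed to keep the queue and the map in one-to-one correspondence. The sketch below isolates that step under simplifying assumptions; the function name InsertOrRollBack and the Queue and Cache aliases are hypothetical, with string keys and int values in place of ComputationRequest and NnetComputation.

```cpp
#include <cassert>
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical simplified types: std::string keys instead of
// ComputationRequest pointers, int values instead of NnetComputation.
using Queue = std::list<std::string>;
using Cache = std::unordered_map<
    std::string, std::pair<std::shared_ptr<const int>, Queue::iterator>>;

// Mirrors the second half of Insert(): append the key to the access queue,
// try to add the map entry, and undo the queue append if the key was already
// cached (e.g. another thread compiled the same request first). Returns true
// if `value` was stored.
bool InsertOrRollBack(Cache &cache, Queue &queue, const std::string &key,
                      std::shared_ptr<const int> value) {
  Queue::iterator ait = queue.insert(queue.end(), key);
  std::pair<Cache::iterator, bool> p =
      cache.insert(std::make_pair(key, std::make_pair(value, ait)));
  if (!p.second) queue.erase(ait);
  // Invariant: every cached key has exactly one node in the access queue.
  assert(queue.size() == cache.size());
  return p.second;
}
```

As in Insert(), the caller whose insertion lost the race still receives a usable computation (its own copy); it simply is not stored in the cache.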

◆ Read()

void Read ( std::istream &  is,
bool  binary 
)

Definition at line 5018 of file nnet-optimize-utils.cc.

References ComputationCache::access_queue_, ComputationCache::computation_cache_, kaldi::ExpectToken(), ComputationCache::Insert(), KALDI_ASSERT, ComputationRequest::Read(), NnetComputation::Read(), and kaldi::ReadBasicType().

Referenced by CachingOptimizingCompiler::ReadCache().

void ComputationCache::Read(std::istream &is, bool binary) {
  // Note: for back-compatibility, the object on disk is not wrapped in a
  // "<ComputationCache>" ... "</ComputationCache>" pair; "<ComputationCache>"
  // appears only after the size, and there is no closing token.
  int32 computation_cache_size;
  ExpectToken(is, binary, "<ComputationCacheSize>");
  ReadBasicType(is, binary, &computation_cache_size);
  KALDI_ASSERT(computation_cache_size >= 0);
  computation_cache_.clear();
  access_queue_.clear();
  ExpectToken(is, binary, "<ComputationCache>");
  for (size_t c = 0; c < computation_cache_size; c++) {
    ComputationRequest request;
    request.Read(is, binary);
    NnetComputation *computation = new NnetComputation();
    computation->Read(is, binary);
    Insert(request, computation);
  }
}

◆ Write()

void Write ( std::ostream &  os,
bool  binary 
) const

Definition at line 5051 of file nnet-optimize-utils.cc.

References ComputationCache::computation_cache_, kaldi::WriteBasicType(), and kaldi::WriteToken().

Referenced by CachingOptimizingCompiler::WriteCache().

void ComputationCache::Write(std::ostream &os, bool binary) const {
  WriteToken(os, binary, "<ComputationCacheSize>");
  WriteBasicType(os, binary, static_cast<int32>(computation_cache_.size()));
  WriteToken(os, binary, "<ComputationCache>");
  for (CacheType::const_iterator iter = computation_cache_.begin();
       iter != computation_cache_.end(); ++iter) {
    iter->first->Write(os, binary);
    iter->second.first->Write(os, binary);
  }
}
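Write() and Read() agree on a simple layout: the token <ComputationCacheSize>, the entry count, the token <ComputationCache>, then each request followed by its computation, with no closing token. The sketch below mirrors only that token order, using a simplified text-only format; WriteCacheSketch, ReadCacheSketch, RoundTrip and the string-pair Entry type are hypothetical, and the sketch does not reproduce Kaldi's binary encoding or the exact output of WriteToken() and WriteBasicType().

```cpp
#include <cassert>
#include <istream>
#include <ostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical entry: (request, computation), each a single word.
using Entry = std::pair<std::string, std::string>;

// Same token order as Write(): size token, count, opening marker, entries.
void WriteCacheSketch(std::ostream &os, const std::vector<Entry> &entries) {
  os << "<ComputationCacheSize> " << entries.size() << " <ComputationCache> ";
  for (const Entry &e : entries) os << e.first << ' ' << e.second << ' ';
}

// Reads what WriteCacheSketch() wrote, checking tokens as Read() does.
std::vector<Entry> ReadCacheSketch(std::istream &is) {
  std::string token;
  int size = -1;
  is >> token;
  assert(token == "<ComputationCacheSize>");
  is >> size;
  assert(size >= 0);
  is >> token;
  assert(token == "<ComputationCache>");
  std::vector<Entry> entries;
  for (int i = 0; i < size; i++) {
    Entry e;
    is >> e.first >> e.second;
    entries.push_back(e);
  }
  return entries;
}

// Convenience round trip through an in-memory stream.
std::vector<Entry> RoundTrip(const std::vector<Entry> &entries) {
  std::stringstream ss;
  WriteCacheSketch(ss, entries);
  return ReadCacheSketch(ss);
}
```

Because the count is written first, the reader knows exactly how many entries follow, which is why no closing token is needed.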

Member Data Documentation

◆ access_queue_

AqType access_queue_
private

Definition at line 669 of file nnet-optimize-utils.h.

Referenced by ComputationCache::Find(), ComputationCache::Insert(), and ComputationCache::Read().

◆ cache_capacity_

int32 cache_capacity_
private

Definition at line 660 of file nnet-optimize-utils.h.

Referenced by ComputationCache::Insert().

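◆ computation_cache_

CacheType computation_cache_
private

Referenced by ComputationCache::Check(), ComputationCache::Find(), ComputationCache::Insert(), ComputationCache::Read(), ComputationCache::Write(), and ComputationCache::~ComputationCache().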

◆ mutex_

std::mutex mutex_
private

Definition at line 658 of file nnet-optimize-utils.h.

Referenced by ComputationCache::Find(), and ComputationCache::Insert().


The documentation for this class was generated from the following files:

nnet-optimize-utils.h
nnet-optimize-utils.cc