CachingOptimizingCompiler Class Reference

This class enables you to do the compilation and optimization in one call, and also ensures that if a ComputationRequest is identical to one that was compiled previously, the compilation process is not repeated.

#include <nnet-optimize.h>


Public Member Functions

 CachingOptimizingCompiler (const Nnet &nnet, const CachingOptimizingCompilerOptions config=CachingOptimizingCompilerOptions())
 
 CachingOptimizingCompiler (const Nnet &nnet, const NnetOptimizeOptions &opt_config, const CachingOptimizingCompilerOptions config=CachingOptimizingCompilerOptions())
 Note: nnet is retained as a const reference but opt_config is copied.
 
 ~CachingOptimizingCompiler ()
 
std::shared_ptr< const NnetComputation > Compile (const ComputationRequest &request)
 Does the compilation and returns a const pointer to the result, which is owned by this class, not the caller.
 
void ReadCache (std::istream &is, bool binary)
 
void WriteCache (std::ostream &os, bool binary)
 
void GetSimpleNnetContext (int32 *nnet_left_context, int32 *nnet_right_context)
 

Private Member Functions

std::shared_ptr< const NnetComputation > CompileInternal (const ComputationRequest &request)

std::shared_ptr< const NnetComputation > CompileAndCache (const ComputationRequest &request)

const NnetComputation * CompileViaShortcut (const ComputationRequest &request)

const NnetComputation * CompileNoShortcut (const ComputationRequest &request)
 

Private Attributes

const Nnet & nnet_
 
CachingOptimizingCompilerOptions config_
 
NnetOptimizeOptions opt_config_
 
double seconds_taken_total_
 
double seconds_taken_compile_
 
double seconds_taken_optimize_
 
double seconds_taken_expand_
 
double seconds_taken_check_
 
double seconds_taken_indexes_
 
double seconds_taken_io_
 
ComputationCache cache_
 
int32 nnet_left_context_
 
int32 nnet_right_context_
 

Detailed Description

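This class enables you to do the compilation and optimization in one call. Compiled computations are kept in a cache (cache_, whose capacity is set by CachingOptimizingCompilerOptions::cache_capacity), so if a ComputationRequest is identical to one that was compiled earlier and is still in the cache, the compilation process is not repeated.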

It is safe to call Compile() from multiple parallel threads without additional synchronization; synchronization is managed internally by class ComputationCache.

Definition at line 219 of file nnet-optimize.h.
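Class ComputationCache itself is not documented on this page. The sketch below is a hypothetical, simplified stand-in that shows the contract the listings rely on: Find() returns a null shared_ptr on a miss, Insert() takes ownership of a raw pointer and returns a shared_ptr, and the number of entries is bounded (as cache_(config.cache_capacity) suggests), with a mutex providing the internal synchronization. Keys are strings instead of ComputationRequest, and least-recently-used eviction is an assumption of this sketch, not a statement about Kaldi's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <iterator>
#include <list>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical, simplified stand-in for ComputationCache (not Kaldi's code).
// Keys are strings instead of ComputationRequest; eviction is LRU by assumption.
template <typename T>
class LruComputationCache {
 public:
  explicit LruComputationCache(std::size_t capacity) : capacity_(capacity) {
    assert(capacity > 0);
  }

  // Returns the stored value, or a null pointer on a miss.
  std::shared_ptr<const T> Find(const std::string &key) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = map_.find(key);
    if (it == map_.end()) return nullptr;
    Touch(it->second.second);
    return it->second.first;
  }

  // Takes ownership of 'value'. If 'key' is already present (e.g. another
  // thread compiled the same request first), the existing entry wins and
  // 'value' is deleted when 'ptr' goes out of scope.
  std::shared_ptr<const T> Insert(const std::string &key, const T *value) {
    assert(value != nullptr);
    std::shared_ptr<const T> ptr(value);
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = map_.find(key);
    if (it != map_.end()) {
      Touch(it->second.second);
      return it->second.first;
    }
    if (map_.size() >= capacity_) {
      // Evict the least recently used entry; callers still holding its
      // shared_ptr keep the object alive.
      map_.erase(order_.front());
      order_.pop_front();
    }
    order_.push_back(key);
    map_.emplace(key, Entry(ptr, std::prev(order_.end())));
    return ptr;
  }

  std::size_t Size() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return map_.size();
  }

 private:
  using Entry =
      std::pair<std::shared_ptr<const T>, std::list<std::string>::iterator>;

  // Moves a key to the back of the recency list (most recently used).
  void Touch(std::list<std::string>::iterator pos) {
    order_.splice(order_.end(), order_, pos);
  }

  std::size_t capacity_;
  mutable std::mutex mutex_;
  std::list<std::string> order_;  // front = least recently used
  std::unordered_map<std::string, Entry> map_;
};
```

Because entries are handed out as shared_ptr, a computation evicted from such a cache stays valid for any caller still holding it; this is presumably one reason Compile() now returns std::shared_ptr<const NnetComputation> rather than a raw pointer.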

Constructor & Destructor Documentation

◆ CachingOptimizingCompiler() [1/2]

Definition at line 635 of file nnet-optimize.cc.

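CachingOptimizingCompiler::CachingOptimizingCompiler(
    const Nnet &nnet,
    const CachingOptimizingCompilerOptions config):
    nnet_(nnet), config_(config),
    seconds_taken_total_(0.0), seconds_taken_compile_(0.0),
    seconds_taken_optimize_(0.0), seconds_taken_expand_(0.0),
    seconds_taken_check_(0.0), seconds_taken_indexes_(0.0),
    seconds_taken_io_(0.0), cache_(config.cache_capacity),
    nnet_left_context_(-1), nnet_right_context_(-1) { }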

◆ CachingOptimizingCompiler() [2/2]

Note: nnet is retained as a const reference but opt_config is copied.

Definition at line 645 of file nnet-optimize.cc.

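CachingOptimizingCompiler::CachingOptimizingCompiler(
    const Nnet &nnet,
    const NnetOptimizeOptions &opt_config,
    const CachingOptimizingCompilerOptions config):
    nnet_(nnet), config_(config), opt_config_(opt_config),
    seconds_taken_total_(0.0), seconds_taken_compile_(0.0),
    seconds_taken_optimize_(0.0), seconds_taken_expand_(0.0),
    seconds_taken_check_(0.0), seconds_taken_indexes_(0.0),
    seconds_taken_io_(0.0), cache_(config.cache_capacity),
    nnet_left_context_(-1), nnet_right_context_(-1) { }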
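The note above has a practical consequence: the compiler keeps a reference to the caller's Nnet, so that Nnet must outlive the compiler, whereas later edits to the caller's NnetOptimizeOptions have no effect on it. The toy classes below (ToyNnet, ToyOptimizeOptions and ToyCompilerHolder, all hypothetical) illustrate the difference between a const-reference member and a copied member.

```cpp
#include <cassert>

// Hypothetical stand-ins for Nnet and NnetOptimizeOptions.
struct ToyNnet {
  int num_components = 0;
};
struct ToyOptimizeOptions {
  bool enabled = true;
};

class ToyCompilerHolder {
 public:
  ToyCompilerHolder(const ToyNnet &nnet, const ToyOptimizeOptions &opt_config)
      : nnet_(nnet), opt_config_(opt_config) {}
  int NumComponents() const { return nnet_.num_components; }
  bool OptimizationEnabled() const { return opt_config_.enabled; }

 private:
  const ToyNnet &nnet_;            // reference: caller's object must outlive this one
  ToyOptimizeOptions opt_config_;  // copy: later changes by the caller are not seen
};
```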

◆ ~CachingOptimizingCompiler()

Definition at line 695 of file nnet-optimize.cc.

References KALDI_LOG, CachingOptimizingCompiler::seconds_taken_check_, CachingOptimizingCompiler::seconds_taken_compile_, CachingOptimizingCompiler::seconds_taken_expand_, CachingOptimizingCompiler::seconds_taken_indexes_, CachingOptimizingCompiler::seconds_taken_io_, CachingOptimizingCompiler::seconds_taken_optimize_, and CachingOptimizingCompiler::seconds_taken_total_.

CachingOptimizingCompiler::~CachingOptimizingCompiler() {
  if (seconds_taken_total_ > 0.0 || seconds_taken_io_ > 0.0) {
    std::ostringstream os;
    double seconds_taken_misc = seconds_taken_total_ - seconds_taken_compile_
        - seconds_taken_optimize_ - seconds_taken_expand_
        - seconds_taken_check_ - seconds_taken_indexes_;
    os << std::setprecision(3) << seconds_taken_total_
       << " seconds taken in nnet3 compilation total (breakdown: "
       << seconds_taken_compile_ << " compilation, "
       << seconds_taken_optimize_ << " optimization, "
       << seconds_taken_expand_ << " shortcut expansion, "
       << seconds_taken_check_ << " checking, "
       << seconds_taken_indexes_ << " computing indexes, "
       << seconds_taken_misc << " misc.) + "
       << seconds_taken_io_ << " I/O.";
    KALDI_LOG << os.str();
    // note: the leftover amount is misc things like hashing and == comparisons on
    // computation-requests, and calling RequestIsDecomposable().
  }
}
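The "misc." figure in the destructor's log line is the part of seconds_taken_total_ that the named stages do not account for, while I/O time is reported separately after the "+". The following is a minimal sketch of that bookkeeping; CompileTimings, MiscSeconds and FormatTimings are hypothetical names, and the subtraction is inferred from the printed breakdown.

```cpp
#include <cassert>
#include <iomanip>
#include <sstream>
#include <string>

// Hypothetical plain struct mirroring the seconds_taken_* members.
struct CompileTimings {
  double total = 0.0, compile = 0.0, optimize = 0.0, expand = 0.0,
         check = 0.0, indexes = 0.0, io = 0.0;
};

// "misc" is whatever part of 'total' the named stages do not explain;
// I/O is reported separately and is not subtracted.
double MiscSeconds(const CompileTimings &t) {
  return t.total - t.compile - t.optimize - t.expand - t.check - t.indexes;
}

// Builds a log line in the same shape as the destructor's message.
std::string FormatTimings(const CompileTimings &t) {
  std::ostringstream os;
  os << std::setprecision(3) << t.total
     << " seconds taken in nnet3 compilation total (breakdown: "
     << t.compile << " compilation, " << t.optimize << " optimization, "
     << t.expand << " shortcut expansion, " << t.check << " checking, "
     << t.indexes << " computing indexes, " << MiscSeconds(t)
     << " misc.) + " << t.io << " I/O.";
  return os.str();
}
```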

Member Function Documentation

◆ Compile()

std::shared_ptr< const NnetComputation > Compile ( const ComputationRequest & request)

Does the compilation and returns a const pointer to the result, which is owned by this class, not the caller.

It calls ComputeCudaIndexes() for you, because you wouldn't be able to do this on a const object.

Note: this used to return 'const NnetComputation*'. If you get a compilation failure, just replace 'const NnetComputation*' with 'std::shared_ptr<const NnetComputation>' in the calling code.

Definition at line 716 of file nnet-optimize.cc.

References CachingOptimizingCompiler::CompileInternal(), Timer::Elapsed(), and CachingOptimizingCompiler::seconds_taken_total_.

Referenced by NnetLdaStatsAccumulator::AccStats(), BatchedXvectorComputer::BatchedXvectorComputer(), NnetComputerFromEg::Compute(), NnetDiscriminativeComputeObjf::Compute(), NnetChainComputeProb::Compute(), NnetComputeProb::Compute(), DecodableNnetSimple::DoNnetComputation(), NnetBatchComputer::GetComputation(), kaldi::nnet3::RunNnetComputation(), NnetChainTrainer::Train(), NnetDiscriminativeTrainer::Train(), NnetTrainer::Train(), kaldi::nnet3::UnitTestNnetModelDerivatives(), and kaldi::nnet3::UnitTestNnetOptimizeWithOptions().

std::shared_ptr<const NnetComputation> CachingOptimizingCompiler::Compile(
    const ComputationRequest &in_request) {
  Timer timer;
  std::shared_ptr<const NnetComputation> ans = CompileInternal(in_request);
  seconds_taken_total_ += timer.Elapsed();
  return ans;
}
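Compile() and the private helpers all follow the same pattern: create a Timer, do the work, then add timer.Elapsed() to one of the seconds_taken_* members. A hypothetical RAII helper built on std::chrono can express the same accumulation without the explicit addition; ScopedAccumulatingTimer is an invented name, not a Kaldi class.

```cpp
#include <cassert>
#include <chrono>

// Hypothetical RAII helper: adds the elapsed wall-clock seconds of its
// enclosing scope to *total when it is destroyed.
class ScopedAccumulatingTimer {
 public:
  explicit ScopedAccumulatingTimer(double *total)
      : total_(total), start_(std::chrono::steady_clock::now()) {
    assert(total != nullptr);
  }
  ~ScopedAccumulatingTimer() {
    std::chrono::duration<double> elapsed =
        std::chrono::steady_clock::now() - start_;
    *total_ += elapsed.count();
  }
  ScopedAccumulatingTimer(const ScopedAccumulatingTimer &) = delete;
  ScopedAccumulatingTimer &operator=(const ScopedAccumulatingTimer &) = delete;

 private:
  double *total_;
  std::chrono::steady_clock::time_point start_;
};
```

With such a helper, the body of Compile() could wrap the call to CompileInternal() in a scope holding a ScopedAccumulatingTimer on &seconds_taken_total_. The trade-off is that the timed region becomes implicit in the scope rather than visible as an explicit addition.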

◆ CompileAndCache()

std::shared_ptr< const NnetComputation > CompileAndCache ( const ComputationRequest & request)
private

◆ CompileInternal()

std::shared_ptr< const NnetComputation > CompileInternal ( const ComputationRequest & request)
private

Definition at line 724 of file nnet-optimize.cc.

References CachingOptimizingCompiler::cache_, CachingOptimizingCompiler::CompileNoShortcut(), CachingOptimizingCompiler::CompileViaShortcut(), CachingOptimizingCompiler::config_, ComputationCache::Find(), ComputationCache::Insert(), KALDI_ASSERT, and CachingOptimizingCompilerOptions::use_shortcut.

Referenced by CachingOptimizingCompiler::Compile(), and CachingOptimizingCompiler::CompileViaShortcut().

std::shared_ptr<const NnetComputation> CachingOptimizingCompiler::CompileInternal(
    const ComputationRequest &request) {
  std::shared_ptr<const NnetComputation> ans = cache_.Find(request);
  if (ans != NULL) {
    return ans;
  } else {
    const NnetComputation *computation = NULL;
    if (config_.use_shortcut)
      computation = CompileViaShortcut(request);
    if (computation == NULL)
      computation = CompileNoShortcut(request);
    KALDI_ASSERT(computation != NULL);
    return cache_.Insert(request, computation);
  }
}
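CompileInternal() first consults the cache, then tries the shortcut path if config_.use_shortcut is set, and falls back to full compilation when CompileViaShortcut() returns NULL. The sketch below reproduces only that control flow, with strings in place of ComputationRequest and NnetComputation; ToyCachingCompiler and its callbacks are hypothetical, and its cache is an unbounded, non-thread-safe std::map rather than ComputationCache.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Hypothetical stand-ins: a request and a computation are both plain strings.
using Request = std::string;
using Computation = std::string;
using CompileFn = std::function<const Computation *(const Request &)>;

// Mirrors the control flow of CompileInternal(): cache lookup, optional
// shortcut (which may decline by returning nullptr), then full compilation.
class ToyCachingCompiler {
 public:
  ToyCachingCompiler(bool use_shortcut, CompileFn via_shortcut,
                     CompileFn no_shortcut)
      : use_shortcut_(use_shortcut),
        via_shortcut_(std::move(via_shortcut)),
        no_shortcut_(std::move(no_shortcut)) {}

  std::shared_ptr<const Computation> Compile(const Request &request) {
    auto it = cache_.find(request);
    if (it != cache_.end()) return it->second;  // hit: nothing is recompiled
    const Computation *computation = nullptr;
    if (use_shortcut_) computation = via_shortcut_(request);
    if (computation == nullptr) computation = no_shortcut_(request);
    assert(computation != nullptr);
    std::shared_ptr<const Computation> ans(computation);  // cache takes ownership
    cache_.emplace(request, ans);
    return ans;
  }

 private:
  bool use_shortcut_;
  CompileFn via_shortcut_;
  CompileFn no_shortcut_;
  std::map<Request, std::shared_ptr<const Computation>> cache_;
};
```

In the real class the shortcut path itself calls CompileInternal() on a smaller request, so the small computation is cached as well; the toy shortcut here does not model that recursion.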

◆ CompileNoShortcut()

const NnetComputation * CompileNoShortcut ( const ComputationRequest & request)
private

Definition at line 741 of file nnet-optimize.cc.

References ComputationChecker::Check(), CheckComputationOptions::check_rewrite, NnetComputation::ComputeCudaIndexes(), Compiler::CreateComputation(), Timer::Elapsed(), kaldi::GetVerboseLevel(), KALDI_LOG, kaldi::nnet3::MaxOutputTimeInRequest(), CachingOptimizingCompiler::nnet_, CachingOptimizingCompiler::opt_config_, kaldi::nnet3::Optimize(), ComputationRequest::Print(), NnetComputation::Print(), CachingOptimizingCompiler::seconds_taken_check_, CachingOptimizingCompiler::seconds_taken_compile_, CachingOptimizingCompiler::seconds_taken_indexes_, and CachingOptimizingCompiler::seconds_taken_optimize_.

Referenced by CachingOptimizingCompiler::CompileInternal().

const NnetComputation* CachingOptimizingCompiler::CompileNoShortcut(
    const ComputationRequest &request) {

  Compiler compiler(request, nnet_);
  // note: 'opts' only contains 'output_debug_info', which is true by default.
  // There may be situations where we'd prefer not to keep it, for speed.
  CompilerOptions opts;
  NnetComputation *computation = new NnetComputation;

  {
    Timer timer;
    compiler.CreateComputation(opts, computation);
    seconds_taken_compile_ += timer.Elapsed();
  }

  int32 verbose_cutoff = 4;
  if (GetVerboseLevel() >= verbose_cutoff) {
    std::ostringstream os1;
    request.Print(os1);
    KALDI_LOG << "Computation request is " << os1.str();
    std::ostringstream os2;
    computation->Print(os2, nnet_);
    KALDI_LOG << "Generated computation is: " << os2.str();
  }

  { // some checking. Note: there may come a time when we might
    // prefer to disable this checking.
    Timer timer;
    CheckComputationOptions check_config;
    // we can do the rewrite check since it's before optimization.
    check_config.check_rewrite = true;
    ComputationChecker checker(check_config, nnet_, *computation);
    checker.Check();
    seconds_taken_check_ += timer.Elapsed();
  }

  {
    Timer timer;
    Optimize(opt_config_, nnet_,
             MaxOutputTimeInRequest(request),
             computation);
    seconds_taken_optimize_ += timer.Elapsed();
  }

  if (GetVerboseLevel() >= verbose_cutoff) {
    std::ostringstream os;
    computation->Print(os, nnet_);
    KALDI_LOG << "Optimized computation is: " << os.str();
  }

  { // check the computation again.
    Timer timer;
    CheckComputationOptions check_config;
    ComputationChecker checker(check_config, nnet_, *computation);
    checker.Check();
    seconds_taken_check_ += timer.Elapsed();
  }

  {
    Timer timer;
    computation->ComputeCudaIndexes();
    seconds_taken_indexes_ += timer.Elapsed();
  }
  return computation;
}
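CompileNoShortcut() runs a fixed sequence of stages: compilation, a check with check_rewrite = true, optimization, a second check without the rewrite check, and computation of the CUDA indexes, adding each stage's time to its own seconds_taken_* counter. The toy below mirrors that ordering. ToyComputation, StageSeconds, Timed, ToyCheck and ToyCompileNoShortcut are hypothetical, and ToyCheck models only the constraint stated in the listing's comment, namely that the rewrite check is done before optimization.

```cpp
#include <cassert>
#include <chrono>
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy computation that records which stages ran.
struct ToyComputation {
  std::vector<std::string> log;
  bool optimized = false;
};

// Per-stage wall-clock totals, analogous to the seconds_taken_* members.
struct StageSeconds {
  double compile = 0.0, check = 0.0, optimize = 0.0, indexes = 0.0;
};

// Runs 'stage' and adds its wall-clock duration to *total.
template <typename F>
void Timed(double *total, F stage) {
  auto start = std::chrono::steady_clock::now();
  stage();
  std::chrono::duration<double> elapsed =
      std::chrono::steady_clock::now() - start;
  *total += elapsed.count();
}

// Stand-in for the checker: the rewrite check is only requested pre-optimization.
void ToyCheck(const ToyComputation &c, bool check_rewrite) {
  assert(!(check_rewrite && c.optimized));
}

std::unique_ptr<ToyComputation> ToyCompileNoShortcut(const std::string &request,
                                                     StageSeconds *secs) {
  std::unique_ptr<ToyComputation> c(new ToyComputation);
  Timed(&secs->compile, [&] { c->log.push_back("compile " + request); });
  Timed(&secs->check, [&] { ToyCheck(*c, true); c->log.push_back("check+rewrite"); });
  Timed(&secs->optimize, [&] { c->optimized = true; c->log.push_back("optimize"); });
  Timed(&secs->check, [&] { ToyCheck(*c, false); c->log.push_back("check"); });
  Timed(&secs->indexes, [&] { c->log.push_back("indexes"); });
  return c;
}
```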

◆ CompileViaShortcut()

const NnetComputation * CompileViaShortcut ( const ComputationRequest & request)
private

Definition at line 808 of file nnet-optimize.cc.

References kaldi::nnet3::CheckComputation(), CachingOptimizingCompiler::CompileInternal(), NnetComputation::ComputeCudaIndexes(), Timer::Elapsed(), kaldi::nnet3::ExpandComputation(), kaldi::GetVerboseLevel(), ComputationRequest::misc_info, CachingOptimizingCompiler::nnet_, kaldi::nnet3::RequestIsDecomposable(), CachingOptimizingCompiler::seconds_taken_expand_, and CachingOptimizingCompiler::seconds_taken_indexes_.

Referenced by CachingOptimizingCompiler::CompileInternal().

const NnetComputation* CachingOptimizingCompiler::CompileViaShortcut(
    const ComputationRequest &request) {
  int32 num_n_values;
  ComputationRequest mini_request;
  if (!RequestIsDecomposable(request, &mini_request, &num_n_values))
    return NULL;

  // By invoking CompileInternal() on the mini request, we go through the same
  // caching process as for any externally requested computation.
  std::shared_ptr<const NnetComputation> mini_computation =
      CompileInternal(mini_request);

  // note: by default we always create debug_info, even in regular compilation.
  // (e.g. it defaults to true in CompilerOptions). If it really seems to be a
  // significant overhead, we can revisit this at some point in future.
  bool need_debug_info = true;

  NnetComputation *ans = new NnetComputation();

  {
    Timer timer;
    ExpandComputation(nnet_, request.misc_info, *mini_computation,
                      need_debug_info, num_n_values, ans);
    seconds_taken_expand_ += timer.Elapsed();
  }
  if (GetVerboseLevel() >= 3) {
    CheckComputation(nnet_, *ans, false);
  }

  {
    Timer timer;
    ans->ComputeCudaIndexes();
    seconds_taken_indexes_ += timer.Elapsed();
  }
  return ans;
}
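CompileViaShortcut() asks RequestIsDecomposable() for a smaller "mini" request plus a count num_n_values, compiles the mini request through CompileInternal() (and therefore through the cache), and then calls ExpandComputation() to build the full computation. The exact decomposability rules are not shown on this page, so the sketch below uses a deliberately simplified, hypothetical rule on (n, t) index pairs: a request decomposes if it contains every combination of n in [0, N), with N >= 2, and a shared set of t values exactly once, and the mini request keeps n = 0 only. ToyRequestIsDecomposable and ToyExpand are invented names.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <utility>
#include <vector>

// Hypothetical toy model: an index is (n, t), where n is the sequence index
// within the minibatch and t is the time index.
using Index = std::pair<int, int>;
using ToyRequest = std::vector<Index>;

// Simplified decomposability test (not Kaldi's rule). On success, fills in a
// mini request containing only n = 0 and sets *num_n_values to N.
bool ToyRequestIsDecomposable(const ToyRequest &request,
                              ToyRequest *mini_request, int *num_n_values) {
  std::set<int> ns, ts;
  for (const Index &idx : request) {
    ns.insert(idx.first);
    ts.insert(idx.second);
  }
  std::size_t n_count = ns.size();
  // n values must be exactly 0, 1, ..., N-1 with N >= 2.
  if (n_count < 2 || *ns.begin() != 0 ||
      *ns.rbegin() != static_cast<int>(n_count) - 1)
    return false;
  // Every (n, t) combination must appear exactly once.
  std::set<Index> distinct(request.begin(), request.end());
  if (distinct.size() != request.size() ||
      request.size() != n_count * ts.size())
    return false;
  mini_request->clear();
  for (int t : ts) mini_request->push_back(Index(0, t));
  *num_n_values = static_cast<int>(n_count);
  return true;
}

// Toy expansion: replicate the mini request's t values for every n.
ToyRequest ToyExpand(const ToyRequest &mini_request, int num_n_values) {
  ToyRequest out;
  for (int n = 0; n < num_n_values; ++n)
    for (const Index &idx : mini_request) out.push_back(Index(n, idx.second));
  return out;
}
```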

◆ GetSimpleNnetContext()

void GetSimpleNnetContext ( int32 * nnet_left_context,
int32 * nnet_right_context 
)

Definition at line 656 of file nnet-optimize.cc.

References kaldi::nnet3::ComputeSimpleNnetContext(), CachingOptimizingCompiler::nnet_, CachingOptimizingCompiler::nnet_left_context_, and CachingOptimizingCompiler::nnet_right_context_.

Referenced by DecodableNnetSimple::DecodableNnetSimple().

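void CachingOptimizingCompiler::GetSimpleNnetContext(
    int32 *nnet_left_context, int32 *nnet_right_context) {
  if (nnet_left_context_ == -1) {
    ComputeSimpleNnetContext(nnet_, &nnet_left_context_,
                             &nnet_right_context_);
  }
  *nnet_left_context = nnet_left_context_;
  *nnet_right_context = nnet_right_context_;
}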
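GetSimpleNnetContext() computes the context lazily: it checks whether nnet_left_context_ is still -1 and only then computes the context from the network, so later calls just copy the stored values. A generic version of that sentinel-based caching is sketched below; LazyContext is a hypothetical class, and its callback stands in for ComputeSimpleNnetContext().

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Hypothetical helper: caches a (left, right) context pair computed on first
// use, with -1 as the "not yet computed" sentinel.
class LazyContext {
 public:
  explicit LazyContext(std::function<std::pair<int, int>()> compute)
      : compute_(std::move(compute)) {}

  void Get(int *left, int *right) {
    assert(left != nullptr && right != nullptr);
    if (left_ == -1) {  // first call: compute and remember
      std::pair<int, int> lr = compute_();
      left_ = lr.first;
      right_ = lr.second;
    }
    *left = left_;
    *right = right_;
  }

 private:
  std::function<std::pair<int, int>()> compute_;
  int left_ = -1;
  int right_ = -1;
};
```

The listing performs no locking around this lazy initialization, so, unlike Compile(), it should not be assumed safe for concurrent first calls; std::call_once would be one way to make such a helper thread-safe.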

◆ ReadCache()

void ReadCache ( std::istream &  is,
bool  binary 
)

Definition at line 666 of file nnet-optimize.cc.

References CachingOptimizingCompiler::cache_, ComputationCache::Check(), Timer::Elapsed(), kaldi::GetVerboseLevel(), CachingOptimizingCompiler::nnet_, CachingOptimizingCompiler::opt_config_, NnetOptimizeOptions::Read(), ComputationCache::Read(), CachingOptimizingCompiler::seconds_taken_check_, CachingOptimizingCompiler::seconds_taken_io_, and CachingOptimizingCompiler::seconds_taken_total_.

Referenced by main(), NnetChainTrainer::NnetChainTrainer(), NnetDiscriminativeTrainer::NnetDiscriminativeTrainer(), and NnetTrainer::NnetTrainer().

void CachingOptimizingCompiler::ReadCache(std::istream &is, bool binary) {
  {
    Timer timer;
    NnetOptimizeOptions opt_config_cached;
    opt_config_cached.Read(is, binary);
    // we won't read cached computations if any optimize option has been changed.
    if (!(opt_config_ == opt_config_cached))
      return;
    cache_.Read(is, binary);
    seconds_taken_io_ += timer.Elapsed();
  }
  if (GetVerboseLevel() >= 2) {
    Timer timer;
    cache_.Check(nnet_);
    seconds_taken_check_ += timer.Elapsed();
    // we consider the check time part of the total time... this is very
    // arbitrary but it only affects printed times-taken.
    seconds_taken_total_ += timer.Elapsed();
  }
}

◆ WriteCache()

void WriteCache ( std::ostream &  os,
bool  binary 
)
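ReadCache() reads an NnetOptimizeOptions object before the cached computations and silently skips the cache if those options differ from opt_config_. WriteCache() has no listing on this page; the reading order suggests that it writes the options first. The sketch below illustrates that "options as a validity stamp" idea with a hypothetical text format; ToyCacheOptions (whose fields are invented), WriteToyCache and ReadToyCache are not part of Kaldi, and this is not Kaldi's serialization format.

```cpp
#include <cassert>
#include <cstddef>
#include <istream>
#include <map>
#include <ostream>
#include <sstream>
#include <string>

// Hypothetical stand-in for NnetOptimizeOptions; only equality matters here.
struct ToyCacheOptions {
  bool enabled = true;
  int level = 1;
  bool operator==(const ToyCacheOptions &other) const {
    return enabled == other.enabled && level == other.level;
  }
};

using ToyCache = std::map<std::string, std::string>;

// Writes the options first, then the entries (no whitespace in keys/values).
void WriteToyCache(std::ostream &os, const ToyCacheOptions &opts,
                   const ToyCache &cache) {
  os << opts.enabled << ' ' << opts.level << '\n' << cache.size() << '\n';
  for (const auto &entry : cache)
    os << entry.first << ' ' << entry.second << '\n';
}

// Adds the stored entries to *cache only if the stored options match
// 'current'; returns false (leaving *cache untouched) on a mismatch or error.
bool ReadToyCache(std::istream &is, const ToyCacheOptions &current,
                  ToyCache *cache) {
  ToyCacheOptions stored;
  std::size_t count = 0;
  if (!(is >> stored.enabled >> stored.level >> count)) return false;
  if (!(stored == current)) return false;  // compiled under other options
  ToyCache loaded;
  for (std::size_t i = 0; i < count; ++i) {
    std::string key, value;
    if (!(is >> key >> value)) return false;
    loaded[key] = value;
  }
  cache->insert(loaded.begin(), loaded.end());
  return true;
}
```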

Member Data Documentation

◆ cache_

◆ config_

Definition at line 289 of file nnet-optimize.h.

Referenced by CachingOptimizingCompiler::CompileInternal().

◆ nnet_

◆ nnet_left_context_

int32 nnet_left_context_
private

Definition at line 305 of file nnet-optimize.h.

Referenced by CachingOptimizingCompiler::GetSimpleNnetContext().

◆ nnet_right_context_

int32 nnet_right_context_
private

Definition at line 306 of file nnet-optimize.h.

Referenced by CachingOptimizingCompiler::GetSimpleNnetContext().

◆ opt_config_

◆ seconds_taken_check_

◆ seconds_taken_compile_

double seconds_taken_compile_
private

◆ seconds_taken_expand_

double seconds_taken_expand_
private

◆ seconds_taken_indexes_

◆ seconds_taken_io_

◆ seconds_taken_optimize_

double seconds_taken_optimize_
private

◆ seconds_taken_total_


The documentation for this class was generated from the following files:
nnet-optimize.h
nnet-optimize.cc