CachingOptimizingCompiler Class Reference

This class enables you to do the compilation and optimization in one call, and also ensures that if the ComputationRequest is identical to the previous one, the compilation process is not repeated.

#include <nnet-optimize.h>


Public Member Functions

 CachingOptimizingCompiler (const Nnet &nnet, const CachingOptimizingCompilerOptions config=CachingOptimizingCompilerOptions())
 
 CachingOptimizingCompiler (const Nnet &nnet, const NnetOptimizeOptions &opt_config, const CachingOptimizingCompilerOptions config=CachingOptimizingCompilerOptions())
 Note: nnet is retained as a const reference but opt_config is copied.
 
 ~CachingOptimizingCompiler ()
 
std::shared_ptr< const NnetComputation > Compile (const ComputationRequest &request)
 Does the compilation and returns a const pointer to the result, which is owned by this class, not the caller.
 
void ReadCache (std::istream &is, bool binary)
 
void WriteCache (std::ostream &os, bool binary)
 

Private Member Functions

std::shared_ptr< const NnetComputation > CompileInternal (const ComputationRequest &request)
 
std::shared_ptr< const NnetComputation > CompileAndCache (const ComputationRequest &request)
 
const NnetComputation * CompileViaShortcut (const ComputationRequest &request)
 
const NnetComputation * CompileNoShortcut (const ComputationRequest &request)
 

Private Attributes

const Nnet & nnet_
 
CachingOptimizingCompilerOptions config_
 
NnetOptimizeOptions opt_config_
 
double seconds_taken_total_
 
double seconds_taken_compile_
 
double seconds_taken_optimize_
 
double seconds_taken_expand_
 
double seconds_taken_check_
 
double seconds_taken_indexes_
 
double seconds_taken_io_
 
ComputationCache cache_
 

Detailed Description

This class enables you to do the compilation and optimization in one call, and also ensures that if the ComputationRequest is identical to the previous one, the compilation process is not repeated.

It is safe to call Compile() from multiple parallel threads without additional synchronization; synchronization is managed internally by class ComputationCache.
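
The caching and locking behaviour described above can be pictured with a small stand-alone sketch. It does not use Kaldi: `ToyRequest`, `ToyComputation`, and `ToyCache` are hypothetical stand-ins for ComputationRequest, NnetComputation, and ComputationCache, and the real cache's hashing and eviction policy are not modelled here.

```cpp
#include <map>
#include <memory>
#include <mutex>
#include <string>

// Hypothetical stand-ins for illustration only; these are not Kaldi types.
struct ToyRequest {
  std::string key;
  bool operator<(const ToyRequest &other) const { return key < other.key; }
};

struct ToyComputation {
  std::string description;
};

class ToyCache {
 public:
  // Returns the cached computation, or nullptr if the request is unseen.
  std::shared_ptr<const ToyComputation> Find(const ToyRequest &request) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = cache_.find(request);
    if (it == cache_.end()) return nullptr;
    return it->second;
  }

  // Takes ownership of 'computation'. If the request is already cached
  // (e.g. another thread got there first), the existing entry is returned
  // and the newly supplied object is freed when this function returns.
  std::shared_ptr<const ToyComputation> Insert(
      const ToyRequest &request, const ToyComputation *computation) {
    std::shared_ptr<const ToyComputation> candidate(computation);
    std::lock_guard<std::mutex> lock(mutex_);
    auto result = cache_.emplace(request, candidate);
    return result.first->second;
  }

 private:
  std::mutex mutex_;  // guards cache_
  std::map<ToyRequest, std::shared_ptr<const ToyComputation>> cache_;
};
```

Because every access to the map happens under the mutex and callers only ever receive `std::shared_ptr<const ...>`, several threads can share one cache object without extra locking on their side.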

Definition at line 219 of file nnet-optimize.h.

Constructor & Destructor Documentation

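CachingOptimizingCompiler ( const Nnet & nnet, const CachingOptimizingCompilerOptions config = CachingOptimizingCompilerOptions() )

Definition at line 634 of file nnet-optimize.cc.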

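CachingOptimizingCompiler ( const Nnet & nnet, const NnetOptimizeOptions & opt_config, const CachingOptimizingCompilerOptions config = CachingOptimizingCompilerOptions() )

Note: nnet is retained as a const reference but opt_config is copied.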

Definition at line 643 of file nnet-optimize.cc.

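CachingOptimizingCompiler::CachingOptimizingCompiler(
    const Nnet &nnet, const NnetOptimizeOptions &opt_config,
    const CachingOptimizingCompilerOptions config):
    nnet_(nnet), config_(config), opt_config_(opt_config),
    seconds_taken_total_(0.0), seconds_taken_compile_(0.0),
    seconds_taken_optimize_(0.0), seconds_taken_expand_(0.0),
    seconds_taken_check_(0.0), seconds_taken_indexes_(0.0),
    seconds_taken_io_(0.0), cache_(config.cache_capacity) { }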

~CachingOptimizingCompiler ( )

Definition at line 683 of file nnet-optimize.cc.

References KALDI_LOG, CachingOptimizingCompiler::seconds_taken_check_, CachingOptimizingCompiler::seconds_taken_compile_, CachingOptimizingCompiler::seconds_taken_expand_, CachingOptimizingCompiler::seconds_taken_indexes_, CachingOptimizingCompiler::seconds_taken_io_, CachingOptimizingCompiler::seconds_taken_optimize_, and CachingOptimizingCompiler::seconds_taken_total_.

CachingOptimizingCompiler::~CachingOptimizingCompiler() {
  if (seconds_taken_total_ > 0.0 || seconds_taken_io_ > 0.0) {
    std::ostringstream os;
    double seconds_taken_misc = seconds_taken_total_ - seconds_taken_compile_
        - seconds_taken_optimize_ - seconds_taken_expand_
        - seconds_taken_check_ - seconds_taken_indexes_;
    os << std::setprecision(3) << seconds_taken_total_
       << " seconds taken in nnet3 compilation total (breakdown: "
       << seconds_taken_compile_ << " compilation, "
       << seconds_taken_optimize_ << " optimization, "
       << seconds_taken_expand_ << " shortcut expansion, "
       << seconds_taken_check_ << " checking, "
       << seconds_taken_indexes_ << " computing indexes, "
       << seconds_taken_misc << " misc.) + "
       << seconds_taken_io_ << " I/O.";
    KALDI_LOG << os.str();
    // note: the leftover amount is misc things like hashing and == comparisons on
    // computation-requests, and calling RequestIsDecomposable().
  }
}
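
As a rough illustration of how the destructor's log line is assembled, the sketch below reproduces the arithmetic and formatting outside Kaldi. `ToyTimings` and `FormatTimings` are hypothetical names introduced here; the "misc" figure is assumed to be the total minus the five named stages, as the comment in the listing suggests.

```cpp
#include <iomanip>
#include <sstream>
#include <string>

// Hypothetical container for the accumulated timings (seconds).
struct ToyTimings {
  double total, compile, optimize, expand, check, indexes, io;
};

// Builds the summary string; "misc" is whatever part of the total is not
// accounted for by the individually timed stages.
std::string FormatTimings(const ToyTimings &t) {
  double misc = t.total - t.compile - t.optimize - t.expand
      - t.check - t.indexes;
  std::ostringstream os;
  os << std::setprecision(3) << t.total
     << " seconds taken in nnet3 compilation total (breakdown: "
     << t.compile << " compilation, "
     << t.optimize << " optimization, "
     << t.expand << " shortcut expansion, "
     << t.check << " checking, "
     << t.indexes << " computing indexes, "
     << misc << " misc.) + "
     << t.io << " I/O.";
  return os.str();
}
```

Note that I/O time is reported separately after the closing parenthesis rather than being subtracted from the total.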

Member Function Documentation

std::shared_ptr< const NnetComputation > Compile ( const ComputationRequest & request )

Does the compilation and returns a const pointer to the result, which is owned by this class, not the caller.

It calls ComputeCudaIndexes() for you, because you wouldn't be able to do this on a const object.

Note: this used to return 'const NnetComputation*'. If you get a compilation failure, just replace 'const NnetComputation*' with 'std::shared_ptr<const NnetComputation>' in the calling code.
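
The migration described in the note might look like the following at a call site. This is a stand-alone sketch: `ToyComputation` and `ToyCompile()` are hypothetical stand-ins for NnetComputation and Compile(), used only so the snippet builds without Kaldi.

```cpp
#include <memory>

// Hypothetical stand-in for NnetComputation.
struct ToyComputation {
  int num_commands;
};

// Hypothetical stand-in for CachingOptimizingCompiler::Compile().
std::shared_ptr<const ToyComputation> ToyCompile() {
  return std::make_shared<ToyComputation>(ToyComputation{3});
}

int CountCommands() {
  // Old style, which no longer compiles against a shared_ptr return type:
  //   const ToyComputation *computation = ToyCompile();
  std::shared_ptr<const ToyComputation> computation = ToyCompile();
  // Code that needs a plain reference can dereference the shared pointer,
  // provided 'computation' stays in scope for as long as 'ref' is used.
  const ToyComputation &ref = *computation;
  return ref.num_commands;
}
```

Holding the `std::shared_ptr` for the duration of use keeps the object alive independently of what the producer does with its own reference.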

Definition at line 704 of file nnet-optimize.cc.

References CachingOptimizingCompiler::CompileInternal(), Timer::Elapsed(), and CachingOptimizingCompiler::seconds_taken_total_.

Referenced by NnetLdaStatsAccumulator::AccStats(), NnetComputerFromEg::Compute(), NnetDiscriminativeComputeObjf::Compute(), NnetChainComputeProb::Compute(), NnetComputeProb::Compute(), DecodableNnetSimple::DoNnetComputation(), kaldi::nnet3::RunNnetComputation(), NnetChainTrainer::Train(), NnetDiscriminativeTrainer::Train(), NnetTrainer::Train(), kaldi::nnet3::UnitTestNnetModelDerivatives(), and kaldi::nnet3::UnitTestNnetOptimizeWithOptions().

std::shared_ptr<const NnetComputation> CachingOptimizingCompiler::Compile(
    const ComputationRequest &in_request) {
  Timer timer;
  std::shared_ptr<const NnetComputation> ans = CompileInternal(in_request);
  seconds_taken_total_ += timer.Elapsed();
  return ans;
}
std::shared_ptr< const NnetComputation > CompileAndCache ( const ComputationRequest & request )
private

std::shared_ptr< const NnetComputation > CompileInternal ( const ComputationRequest & request )
private

Definition at line 712 of file nnet-optimize.cc.

References CachingOptimizingCompiler::cache_, CachingOptimizingCompiler::CompileNoShortcut(), CachingOptimizingCompiler::CompileViaShortcut(), CachingOptimizingCompiler::config_, ComputationCache::Find(), ComputationCache::Insert(), KALDI_ASSERT, and CachingOptimizingCompilerOptions::use_shortcut.

Referenced by CachingOptimizingCompiler::Compile(), and CachingOptimizingCompiler::CompileViaShortcut().

std::shared_ptr<const NnetComputation> CachingOptimizingCompiler::CompileInternal(
    const ComputationRequest &request) {
  std::shared_ptr<const NnetComputation> ans = cache_.Find(request);
  if (ans != NULL) {
    return ans;
  } else {
    const NnetComputation *computation = NULL;
    if (config_.use_shortcut)
      computation = CompileViaShortcut(request);
    if (computation == NULL)
      computation = CompileNoShortcut(request);
    KALDI_ASSERT(computation != NULL);
    return cache_.Insert(request, computation);
  }
}
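
The control flow of CompileInternal() (cache lookup, then an optional shortcut attempt, then the full compile as a fallback) can be sketched in isolation as below. Everything here is hypothetical: `ToyCompiler` and `ToyComputation` are stand-ins, requests are plain strings, the rule that only requests starting with "batch:" are decomposable is invented for the example, and the sketch is single-threaded.

```cpp
#include <map>
#include <memory>
#include <string>

// Hypothetical stand-in for NnetComputation; records which path built it.
struct ToyComputation {
  std::string built_by;
};

class ToyCompiler {
 public:
  explicit ToyCompiler(bool use_shortcut) : use_shortcut_(use_shortcut) {}

  std::shared_ptr<const ToyComputation> CompileInternal(
      const std::string &request) {
    auto it = cache_.find(request);
    if (it != cache_.end()) return it->second;  // cache hit: nothing to do

    const ToyComputation *computation = nullptr;
    if (use_shortcut_) computation = CompileViaShortcut(request);
    // A null result means the shortcut declined; fall back to the full path.
    if (computation == nullptr) computation = CompileNoShortcut(request);

    std::shared_ptr<const ToyComputation> ans(computation);
    cache_[request] = ans;
    return ans;
  }

 private:
  // Invented rule: only requests starting with "batch:" are decomposable.
  static const ToyComputation *CompileViaShortcut(const std::string &request) {
    if (request.compare(0, 6, "batch:") != 0) return nullptr;
    return new ToyComputation{"shortcut"};
  }
  static const ToyComputation *CompileNoShortcut(const std::string &) {
    return new ToyComputation{"full"};
  }

  bool use_shortcut_;
  std::map<std::string, std::shared_ptr<const ToyComputation>> cache_;
};
```

Using a null return to signal "shortcut not applicable" keeps the fallback a single `if`, which is the same shape as the listing above.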
const NnetComputation * CompileNoShortcut ( const ComputationRequest & request )
private

Definition at line 729 of file nnet-optimize.cc.

References ComputationChecker::Check(), CheckComputationOptions::check_rewrite, NnetComputation::ComputeCudaIndexes(), Compiler::CreateComputation(), Timer::Elapsed(), kaldi::GetVerboseLevel(), KALDI_LOG, kaldi::nnet3::MaxOutputTimeInRequest(), CachingOptimizingCompiler::nnet_, CachingOptimizingCompiler::opt_config_, kaldi::nnet3::Optimize(), ComputationRequest::Print(), NnetComputation::Print(), CachingOptimizingCompiler::seconds_taken_check_, CachingOptimizingCompiler::seconds_taken_compile_, CachingOptimizingCompiler::seconds_taken_indexes_, and CachingOptimizingCompiler::seconds_taken_optimize_.

Referenced by CachingOptimizingCompiler::CompileInternal().

const NnetComputation *CachingOptimizingCompiler::CompileNoShortcut(
    const ComputationRequest &request) {

  Compiler compiler(request, nnet_);
  // note: 'opts' only contains 'output_debug_info', which is true by default.
  // There may be situations where we'd prefer not to keep it, for speed.
  CompilerOptions opts;
  NnetComputation *computation = new NnetComputation;

  {
    Timer timer;
    compiler.CreateComputation(opts, computation);
    seconds_taken_compile_ += timer.Elapsed();
  }

  int32 verbose_cutoff = 4;
  if (GetVerboseLevel() >= verbose_cutoff) {
    std::ostringstream os1;
    request.Print(os1);
    KALDI_LOG << "Computation request is " << os1.str();
    std::ostringstream os2;
    computation->Print(os2, nnet_);
    KALDI_LOG << "Generated computation is: " << os2.str();
  }

  { // some checking.  Note: there may come a time when we might
    // prefer to disable this checking.
    Timer timer;
    CheckComputationOptions check_config;
    // we can do the rewrite check since it's before optimization.
    check_config.check_rewrite = true;
    ComputationChecker checker(check_config, nnet_, *computation);
    checker.Check();
    seconds_taken_check_ += timer.Elapsed();
  }

  {
    Timer timer;
    Optimize(opt_config_, nnet_,
             MaxOutputTimeInRequest(request),
             computation);
    seconds_taken_optimize_ += timer.Elapsed();
  }

  if (GetVerboseLevel() >= verbose_cutoff) {
    std::ostringstream os;
    computation->Print(os, nnet_);
    KALDI_LOG << "Optimized computation is: " << os.str();
  }

  { // check the computation again.
    Timer timer;
    CheckComputationOptions check_config;
    ComputationChecker checker(check_config, nnet_, *computation);
    checker.Check();
    seconds_taken_check_ += timer.Elapsed();
  }

  {
    Timer timer;
    computation->ComputeCudaIndexes();
    seconds_taken_indexes_ += timer.Elapsed();
  }
  return computation;
}
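
Each stage in the listing above follows the same pattern: open a scope, start a timer, run the stage, and add the elapsed time to a per-stage accumulator. A minimal stand-alone sketch of that pattern follows; `ToyTimer` and `RunTimed` are hypothetical helpers built on `std::chrono`, not Kaldi's Timer class.

```cpp
#include <chrono>

// Hypothetical wall-clock timer, loosely modelled on the way Timer is used
// in the listing: construction starts the clock, Elapsed() reads it.
class ToyTimer {
 public:
  ToyTimer() : start_(std::chrono::steady_clock::now()) {}

  // Seconds since construction.
  double Elapsed() const {
    std::chrono::duration<double> d =
        std::chrono::steady_clock::now() - start_;
    return d.count();
  }

 private:
  std::chrono::steady_clock::time_point start_;
};

// Runs 'stage' and adds its wall-clock duration to '*seconds_taken'.
template <typename Fn>
void RunTimed(Fn stage, double *seconds_taken) {
  ToyTimer timer;
  stage();
  *seconds_taken += timer.Elapsed();
}
```

Accumulating with `+=` rather than assigning means repeated compilations add up, which is what makes a single end-of-life summary meaningful.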
const NnetComputation * CompileViaShortcut ( const ComputationRequest & request )
private

Definition at line 796 of file nnet-optimize.cc.

References kaldi::nnet3::CheckComputation(), CachingOptimizingCompiler::CompileInternal(), NnetComputation::ComputeCudaIndexes(), Timer::Elapsed(), kaldi::nnet3::ExpandComputation(), kaldi::GetVerboseLevel(), ComputationRequest::misc_info, CachingOptimizingCompiler::nnet_, kaldi::nnet3::RequestIsDecomposable(), CachingOptimizingCompiler::seconds_taken_expand_, and CachingOptimizingCompiler::seconds_taken_indexes_.

Referenced by CachingOptimizingCompiler::CompileInternal().

const NnetComputation *CachingOptimizingCompiler::CompileViaShortcut(
    const ComputationRequest &request) {
  int32 num_n_values;
  ComputationRequest mini_request;
  if (!RequestIsDecomposable(request, &mini_request, &num_n_values))
    return NULL;

  // By invoking CompileInternal() on the mini request, we go through the same
  // caching process as for any externally requested computation.
  std::shared_ptr<const NnetComputation> mini_computation =
      CompileInternal(mini_request);

  // note: by default we always create debug_info, even in regular compilation.
  // (e.g. it defaults to true in CompilerOptions).  If it really seems to be a
  // significant overhead, we can revisit this at some point in future.
  bool need_debug_info = true;


  NnetComputation *ans = new NnetComputation();

  {
    Timer timer;
    ExpandComputation(nnet_, request.misc_info, *mini_computation,
                      need_debug_info, num_n_values, ans);
    seconds_taken_expand_ += timer.Elapsed();
  }
  if (GetVerboseLevel() >= 3) {
    CheckComputation(nnet_, *ans, false);
  }

  {
    Timer timer;
    ans->ComputeCudaIndexes();
    seconds_taken_indexes_ += timer.Elapsed();
  }
  return ans;
}
void ReadCache ( std::istream & is, bool binary )

Definition at line 654 of file nnet-optimize.cc.

References CachingOptimizingCompiler::cache_, ComputationCache::Check(), Timer::Elapsed(), kaldi::GetVerboseLevel(), CachingOptimizingCompiler::nnet_, CachingOptimizingCompiler::opt_config_, NnetOptimizeOptions::Read(), ComputationCache::Read(), CachingOptimizingCompiler::seconds_taken_check_, CachingOptimizingCompiler::seconds_taken_io_, and CachingOptimizingCompiler::seconds_taken_total_.

Referenced by NnetChainTrainer::NnetChainTrainer(), NnetDiscriminativeTrainer::NnetDiscriminativeTrainer(), and NnetTrainer::NnetTrainer().

void CachingOptimizingCompiler::ReadCache(std::istream &is, bool binary) {
  {
    Timer timer;
    NnetOptimizeOptions opt_config_cached;
    opt_config_cached.Read(is, binary);
    // we won't read cached computations if any optimize option has been changed.
    if (!(opt_config_ == opt_config_cached))
      return;
    cache_.Read(is, binary);
    seconds_taken_io_ += timer.Elapsed();
  }
  if (GetVerboseLevel() >= 2) {
    Timer timer;
    cache_.Check(nnet_);
    seconds_taken_check_ += timer.Elapsed();
    // we consider the check time part of the total time... this is very
    // arbitrary but it only affects printed times-taken.
    seconds_taken_total_ += timer.Elapsed();
  }

}
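
The guard in ReadCache() (read the stored options first, and discard the cached computations if they differ from the current options) can be sketched independently of Kaldi as follows. `ToyOptions`, its two fields, the text layout of the stream, and `ReadToyCache` are all hypothetical; the real on-disk format is not modelled.

```cpp
#include <istream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical options block with invented fields, standing in for
// NnetOptimizeOptions.
struct ToyOptions {
  bool enable_pass = true;
  int level = 0;

  bool operator==(const ToyOptions &other) const {
    return enable_pass == other.enable_pass && level == other.level;
  }
  // Reads "<0|1> <int>" from the stream.
  void Read(std::istream &is) { is >> enable_pass >> level; }
};

// Loads cached entries only if the options stored in the stream match
// 'current'. Returns true if the entries were loaded.
bool ReadToyCache(std::istream &is, const ToyOptions &current,
                  std::vector<std::string> *entries) {
  ToyOptions cached;
  cached.Read(is);
  // A cache built under different options may hold computations that the
  // current settings would not produce, so it is ignored entirely.
  if (!(cached == current)) return false;
  std::string entry;
  while (is >> entry) entries->push_back(entry);
  return true;
}
```

Storing the options ahead of the payload lets the reader reject a stale cache before parsing any of the cached entries.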
void WriteCache ( std::ostream & os, bool binary )

Member Data Documentation

Definition at line 279 of file nnet-optimize.h.

Referenced by CachingOptimizingCompiler::CompileInternal().

double seconds_taken_compile_
private
double seconds_taken_expand_
private
double seconds_taken_optimize_
private

The documentation for this class was generated from the following files: