DoBackpropParallelClass Class Reference
Inheritance diagram for DoBackpropParallelClass:
Collaboration diagram for DoBackpropParallelClass:

Public Member Functions

 DoBackpropParallelClass (const Nnet &nnet, ExamplesRepository *repository, double *tot_weight_ptr, double *log_prob_ptr, Nnet *nnet_to_update, bool store_separate_gradients)
 
 DoBackpropParallelClass (const DoBackpropParallelClass &other)
 
void operator() ()
 
 ~DoBackpropParallelClass ()
 
- Public Member Functions inherited from MultiThreadable
virtual ~MultiThreadable ()
 

Private Attributes

const Nnet & nnet_
 
ExamplesRepository * repository_
 
Nnet * nnet_to_update_
 
Nnet * nnet_to_update_orig_
 
bool store_separate_gradients_
 
double * tot_weight_ptr_
 
double * log_prob_ptr_
 
double tot_weight_
 
double log_prob_
 

Additional Inherited Members

- Public Attributes inherited from MultiThreadable
int32 thread_id_
 
int32 num_threads_
 

Detailed Description

Definition at line 29 of file nnet-update-parallel.cc.

Constructor & Destructor Documentation

◆ DoBackpropParallelClass() [1/2]

DoBackpropParallelClass ( const Nnet &  nnet,
ExamplesRepository *  repository,
double *  tot_weight_ptr,
double *  log_prob_ptr,
Nnet *  nnet_to_update,
bool  store_separate_gradients 
)
inline
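
The role of the last two arguments may be easier to see from a small construction sketch. This is not taken from the Kaldi sources: "nnet" (a const Nnet&) and "nnet_to_update" (an Nnet*) are assumed to be in scope, and the rule used for store_separate_gradients is only the typical pattern (keep per-thread gradient copies when nnet_to_update does not alias the model being evaluated).

// Hedged sketch, not actual Kaldi code.
double tot_weight = 0.0, tot_log_prob = 0.0;
ExamplesRepository repository;  // filled with minibatches by the calling thread
// If nnet_to_update aliases the model itself, the threads can update it
// "hogwild"-style; otherwise each thread should keep a separate gradient
// copy that is summed back when the thread's copy of the class is destroyed.
bool store_separate_gradients = (nnet_to_update != &nnet);
DoBackpropParallelClass c(nnet, &repository, &tot_weight,
                          &tot_log_prob, nnet_to_update,
                          store_separate_gradients);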

◆ DoBackpropParallelClass() [2/2]

Definition at line 50 of file nnet-update-parallel.cc.

References DoBackpropParallelClass::nnet_to_update_, Nnet::SetZero(), and DoBackpropParallelClass::store_separate_gradients_.

DoBackpropParallelClass(const DoBackpropParallelClass &other):
    MultiThreadable(other),
    nnet_(other.nnet_),
    repository_(other.repository_),
    nnet_to_update_(other.nnet_to_update_),
    nnet_to_update_orig_(other.nnet_to_update_orig_),
    store_separate_gradients_(other.store_separate_gradients_),
    tot_weight_ptr_(other.tot_weight_ptr_),
    log_prob_ptr_(other.log_prob_ptr_),
    tot_weight_(0),
    log_prob_(0.0) {
  if (store_separate_gradients_) {
    // To ensure correctness, we work on separate copies of the gradient
    // object, which we'll sum at the end.  This is used for exact gradient
    // computation.
    if (other.nnet_to_update_ != NULL) {
      nnet_to_update_ = new Nnet(*(other.nnet_to_update_));
      // Our "nnet_to_update_" variable is a copy of the neural network
      // we are to update (presumably a gradient).  If we don't set it
      // to zero we would end up adding multiple copies of any initial
      // gradient that "nnet_to_update_" contained when we initialized
      // the first instance of the class.
      nnet_to_update_->SetZero(true);
    } else {  // support the case where we don't really need a gradient.
      nnet_to_update_ = NULL;
    }
  }
}
void SetZero(bool treat_as_gradient)
Definition: nnet-nnet.cc:151

◆ ~DoBackpropParallelClass()

Definition at line 98 of file nnet-update-parallel.cc.

References Nnet::AddNnet(), DoBackpropParallelClass::log_prob_, DoBackpropParallelClass::log_prob_ptr_, DoBackpropParallelClass::nnet_to_update_, DoBackpropParallelClass::nnet_to_update_orig_, DoBackpropParallelClass::tot_weight_, and DoBackpropParallelClass::tot_weight_ptr_.

~DoBackpropParallelClass() {
  if (nnet_to_update_orig_ != nnet_to_update_) {
    // This branch is only taken if this instance of the class is
    // one of the multiple instances allocated inside the RunMultiThreaded
    // template function, *and* store_separate_gradients_ has been set to
    // true.  In the typical hogwild case, we don't do this.
    nnet_to_update_orig_->AddNnet(1.0, *nnet_to_update_);
    delete nnet_to_update_;
  }
  *log_prob_ptr_ += log_prob_;
  *tot_weight_ptr_ += tot_weight_;
}
void AddNnet(const VectorBase< BaseFloat > &scales, const Nnet &other)
For each updatable component, adds to it the corresponding element of "other" times the appropriate...
Definition: nnet-nnet.cc:576
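
Putting the copy constructor and the destructor together, the per-thread lifecycle when store_separate_gradients_ is true looks roughly like this. This is an illustrative sketch only; in the real code the copies are created, run and destroyed by Kaldi's threading helper rather than by hand.

// Illustrative only; "c" is a fully constructed master instance.
{
  DoBackpropParallelClass worker(c);  // copy ctor: allocates and zeroes a
                                      // private copy of c's gradient object
  worker();                           // operator(): accumulates the gradient,
                                      // log_prob_ and tot_weight_ locally
}  // dtor: adds the private gradient into the original with AddNnet(),
   // deletes it, and adds log_prob_ / tot_weight_ into *log_prob_ptr_
   // and *tot_weight_ptr_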

Member Function Documentation

◆ operator()()

void operator() ( )
inline virtual

Implements MultiThreadable.

Definition at line 79 of file nnet-update-parallel.cc.

References kaldi::nnet2::ComputeNnetObjf(), kaldi::nnet2::DoBackprop(), KALDI_VLOG, DoBackpropParallelClass::log_prob_, DoBackpropParallelClass::nnet_, DoBackpropParallelClass::nnet_to_update_, ExamplesRepository::ProvideExamples(), DoBackpropParallelClass::repository_, MultiThreadable::thread_id_, DoBackpropParallelClass::tot_weight_, and kaldi::nnet2::TotalNnetTrainingWeight().

void operator() () {
  std::vector<NnetExample> examples;
  while (repository_->ProvideExamples(&examples)) {
    // This is a function call to a function defined in
    // nnet-update.h.
    double tot_loglike;
    if (nnet_to_update_ != NULL)
      tot_loglike = DoBackprop(nnet_, examples, nnet_to_update_);
    else
      tot_loglike = ComputeNnetObjf(nnet_, examples);
    tot_weight_ += TotalNnetTrainingWeight(examples);
    log_prob_ += tot_loglike;
    KALDI_VLOG(4) << "Thread " << thread_id_ << " saw "
                  << tot_weight_ << " frames so far (weighted); likelihood "
                  << "per frame so far is " << (log_prob_ / tot_weight_);
    examples.clear();
  }
}
double ComputeNnetObjf(const Nnet &nnet, const std::vector< NnetExample > &examples, double *tot_accuracy)
Computes objective function over a minibatch.
Definition: nnet-update.cc:258
double DoBackprop(const Nnet &nnet, const std::vector< NnetExample > &examples, Nnet *nnet_to_update, double *tot_accuracy)
This function computes the objective function and either updates the model or adds to parameter gradients.
Definition: nnet-update.cc:265
bool ProvideExamples(std::vector< NnetExample > *examples)
This function is called by the code that does the training.
BaseFloat TotalNnetTrainingWeight(const std::vector< NnetExample > &egs)
Returns the total weight summed over all the examples...
Definition: nnet-update.cc:248
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
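
For context, this operator is the consumer side of a producer/consumer pattern: worker threads run operator() while the calling thread feeds minibatches into the ExamplesRepository. The outline below continues the construction sketch under the first constructor (so "c" and "repository" are the objects created there); "example_reader" and "minibatch_size" are assumed inputs of the caller, and a MultiThreader object from kaldi-thread.h is assumed as the threading helper (the destructor's comment mentions RunMultiThreaded, but a scoped MultiThreader lets the calling thread keep feeding the repository while the workers run). Treat this as a sketch of the pattern, not the exact implementation.

// Hedged outline of the driving code around operator()().
{
  // Spawns g_num_threads copies of "c" via the copy constructor; the copies
  // are re-joined (and their destructors run) when "m" goes out of scope.
  MultiThreader<DoBackpropParallelClass> m(g_num_threads, c);

  std::vector<NnetExample> examples;
  for (; !example_reader->Done(); example_reader->Next()) {
    examples.push_back(example_reader->Value());
    if (examples.size() == minibatch_size)
      repository.AcceptExamples(&examples);  // hands the minibatch to a worker
  }
  if (!examples.empty())
    repository.AcceptExamples(&examples);    // final partial minibatch
  repository.ExamplesDone();                 // makes ProvideExamples() return false
}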

Member Data Documentation

◆ log_prob_

double log_prob_
private

◆ log_prob_ptr_

double* log_prob_ptr_
private

◆ nnet_

const Nnet& nnet_
private

Definition at line 111 of file nnet-update-parallel.cc.

Referenced by DoBackpropParallelClass::operator()().

◆ nnet_to_update_

Nnet* nnet_to_update_
private

◆ nnet_to_update_orig_

Nnet* nnet_to_update_orig_
private

◆ repository_

ExamplesRepository* repository_
private

Definition at line 112 of file nnet-update-parallel.cc.

Referenced by DoBackpropParallelClass::operator()().

◆ store_separate_gradients_

bool store_separate_gradients_
private

◆ tot_weight_

double tot_weight_
private

◆ tot_weight_ptr_

double* tot_weight_ptr_
private

The documentation for this class was generated from the following file:

nnet-update-parallel.cc