DiscTrainParallelClass Class Reference
Inheritance diagram for DiscTrainParallelClass:
Collaboration diagram for DiscTrainParallelClass:

Public Member Functions

 DiscTrainParallelClass (const AmNnet &am_nnet, const TransitionModel &tmodel, const NnetDiscriminativeUpdateOptions &opts, bool store_separate_gradients, DiscriminativeExamplesRepository *repository, Nnet *nnet_to_update, NnetDiscriminativeStats *stats)
 
 DiscTrainParallelClass (const DiscTrainParallelClass &other)
 
void operator() ()
 
 ~DiscTrainParallelClass ()
 
- Public Member Functions inherited from MultiThreadable
virtual ~MultiThreadable ()
 

Private Attributes

const AmNnet & am_nnet_
 
const TransitionModel & tmodel_
 
const NnetDiscriminativeUpdateOptions & opts_
 
bool store_separate_gradients_
 
DiscriminativeExamplesRepository * repository_
 
Nnet * nnet_to_update_
 
Nnet * nnet_to_update_orig_
 
NnetDiscriminativeStats * stats_ptr_
 
NnetDiscriminativeStats stats_
 

Additional Inherited Members

- Public Attributes inherited from MultiThreadable
int32 thread_id_
 
int32 num_threads_
 

Detailed Description

Definition at line 103 of file nnet-compute-discriminative-parallel.cc.
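
This class is the per-thread worker for multi-threaded discriminative training: the driver builds one prototype instance, the multi-threading code copy-constructs it once per thread, each copy pulls examples from the shared repository in operator()(), and the per-thread gradients and stats are merged back when the copies are destroyed. The sketch below shows roughly how a driver ties these pieces together. It is modeled on NnetDiscriminativeUpdateParallel() in the same file, but the function name TrainDiscriminativeParallelSketch and the exact repository/MultiThreader calls (AcceptExample(), ExamplesDone()) should be treated as assumptions for illustration, not a verbatim copy of the Kaldi code.

// Illustrative driver (a sketch, not the actual Kaldi implementation).
// It would live in the same .cc file as DiscTrainParallelClass, inside
// namespace kaldi::nnet2, so no extra includes are shown here.
// Assumes DiscriminativeExamplesRepository provides AcceptExample() /
// ExamplesDone(), and that MultiThreader<C> copy-constructs its prototype
// argument once per thread and joins the threads in its destructor.
void TrainDiscriminativeParallelSketch(
    const AmNnet &am_nnet,
    const TransitionModel &tmodel,
    const NnetDiscriminativeUpdateOptions &opts,
    int32 num_threads,
    SequentialDiscriminativeNnetExampleReader *example_reader,
    Nnet *nnet_to_update,
    NnetDiscriminativeStats *stats) {
  DiscriminativeExamplesRepository repository;
  // Use separate per-thread gradient copies only when accumulating an exact
  // gradient into an object other than the model itself; otherwise the
  // threads update nnet_to_update in place ("hogwild").
  bool store_separate_gradients = (nnet_to_update != &(am_nnet.GetNnet()));
  DiscTrainParallelClass c(am_nnet, tmodel, opts, store_separate_gradients,
                           &repository, nnet_to_update, stats);
  {
    // Spawns the worker threads, each with its own copy of c; they are
    // joined when m goes out of scope, at which point the copies'
    // destructors have added their gradients and stats back into the
    // originals.
    MultiThreader<DiscTrainParallelClass> m(num_threads, c);
    for (; !example_reader->Done(); example_reader->Next())
      repository.AcceptExample(example_reader->Value());
    repository.ExamplesDone();
  }
}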

Constructor & Destructor Documentation

◆ DiscTrainParallelClass() [1/2]

DiscTrainParallelClass ( const AmNnet &  am_nnet,
const TransitionModel &  tmodel,
const NnetDiscriminativeUpdateOptions &  opts,
bool  store_separate_gradients,
DiscriminativeExamplesRepository *  repository,
Nnet *  nnet_to_update,
NnetDiscriminativeStats *  stats 
)
inline

◆ DiscTrainParallelClass() [2/2]

Definition at line 123 of file nnet-compute-discriminative-parallel.cc.

References DiscTrainParallelClass::nnet_to_update_.

  DiscTrainParallelClass(const DiscTrainParallelClass &other) :
      MultiThreadable(other),
      am_nnet_(other.am_nnet_), tmodel_(other.tmodel_), opts_(other.opts_),
      store_separate_gradients_(other.store_separate_gradients_),
      repository_(other.repository_), nnet_to_update_(other.nnet_to_update_),
      nnet_to_update_orig_(other.nnet_to_update_orig_),
      stats_ptr_(other.stats_ptr_) {
    if (store_separate_gradients_) {
      // To ensure correctness, we work on separate copies of the gradient
      // object, which we'll sum at the end. This is used for exact gradient
      // computation.
      if (other.nnet_to_update_ != NULL) {
        nnet_to_update_ = new Nnet(*(other.nnet_to_update_));
        // our "nnet_to_update_" variable is a copy of the neural network
        // we are to update (presumably a gradient). If we don't set these
        // to zero we would end up adding multiple copies of any initial
        // gradient that "nnet_to_update_" contained when we initialize
        // the first instance of the class.
        nnet_to_update_->SetZero(true);
      } else {  // support case where we don't really need a gradient.
        nnet_to_update_ = NULL;
      }
    }
  }
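The zeroing of each per-thread copy matters: the copies all start from the same nnet_to_update_, so any gradient it already holds would otherwise be counted once per thread when the copies are summed back. The following toy program (plain C++, not Kaldi API) illustrates that accumulation pattern; the numbers are purely illustrative.

#include <cassert>
#include <cstddef>
#include <vector>

// Toy stand-in for a gradient object: each worker accumulates into its own
// zero-initialized copy (mirroring SetZero(true) in the copy constructor),
// and the copies are summed into the original at the end (mirroring the
// AddNnet call in the destructor).
int main() {
  std::vector<double> grad_orig = {0.5, 0.5};   // pre-existing gradient
  const int num_threads = 4;
  std::vector<std::vector<double> > per_thread(
      num_threads, std::vector<double>(grad_orig.size(), 0.0));  // zeroed

  // Each "thread" adds its own contribution of 1.0 per element.
  for (int t = 0; t < num_threads; t++)
    for (std::size_t i = 0; i < per_thread[t].size(); i++)
      per_thread[t][i] += 1.0;

  // Destructor-style merge: grad_orig += sum of per-thread copies.
  for (int t = 0; t < num_threads; t++)
    for (std::size_t i = 0; i < grad_orig.size(); i++)
      grad_orig[i] += per_thread[t][i];

  // 0.5 + 4 * 1.0 per element.  Had the copies started from grad_orig
  // instead of zero, we would have over-counted by 4 * 0.5 per element.
  assert(grad_orig[0] == 4.5 && grad_orig[1] == 4.5);
  return 0;
}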

◆ ~DiscTrainParallelClass()

Definition at line 164 of file nnet-compute-discriminative-parallel.cc.

  ~DiscTrainParallelClass() {
    if (nnet_to_update_orig_ != nnet_to_update_) {
      // This branch is only taken if this instance of the class is
      // one of the multiple instances allocated inside the RunMultiThreaded
      // template function, *and* store_separate_gradients_ has been set to true.
      // In the typical hogwild case, we don't do this.
      nnet_to_update_orig_->AddNnet(1.0, *nnet_to_update_);
      delete nnet_to_update_;
    }
    if (stats_ptr_ != &stats_)
      stats_ptr_->Add(stats_);
  }
References Nnet::AddNnet(), and NnetDiscriminativeStats::Add().
Member Function Documentation

◆ operator()()

void operator() ( )
inline virtual

Implements MultiThreadable.

Definition at line 148 of file nnet-compute-discriminative-parallel.cc.

References kaldi::GetVerboseLevel(), KALDI_VLOG, and kaldi::nnet2::NnetDiscriminativeUpdate().

  void operator () () {
    DiscriminativeNnetExample *example;
    while ((example = repository_->ProvideExample()) != NULL) {
      // This is a function call to a function defined in
      // nnet-compute-discriminative.h
      NnetDiscriminativeUpdate(am_nnet_, tmodel_, opts_,
                               *example, nnet_to_update_, &stats_);
      delete example;

      if (GetVerboseLevel() > 3) {
        KALDI_VLOG(3) << "Printing local stats for thread " << thread_id_;
        stats_.Print(opts_.criterion);
      }
    }
  }
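Note that the stats printing is gated twice: KALDI_VLOG(3) itself only prints when the verbose level is at least 3, and the explicit GetVerboseLevel() > 3 check additionally skips the work unless the level is at least 4. Normally the level comes from the standard --verbose command-line option; the helper below (its name is illustrative, not part of Kaldi) sketches how it could be set programmatically.

#include "base/kaldi-error.h"

// Illustrative helper: setting the verbose level to 4 before training makes
// the GetVerboseLevel() > 3 check in operator()() true, which is equivalent
// to passing --verbose=4 to the training binary.
void EnablePerThreadStatsLogging() {
  kaldi::SetVerboseLevel(4);
}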

Member Data Documentation

◆ am_nnet_

const AmNnet& am_nnet_
private

Definition at line 176 of file nnet-compute-discriminative-parallel.cc.

◆ nnet_to_update_

Nnet* nnet_to_update_
private

Definition at line 181 of file nnet-compute-discriminative-parallel.cc.

◆ nnet_to_update_orig_

Nnet* nnet_to_update_orig_
private

Definition at line 182 of file nnet-compute-discriminative-parallel.cc.

◆ opts_

const NnetDiscriminativeUpdateOptions& opts_
private

Definition at line 178 of file nnet-compute-discriminative-parallel.cc.

◆ repository_

DiscriminativeExamplesRepository* repository_
private

Definition at line 180 of file nnet-compute-discriminative-parallel.cc.

◆ stats_

NnetDiscriminativeStats stats_
private

Definition at line 184 of file nnet-compute-discriminative-parallel.cc.

◆ stats_ptr_

NnetDiscriminativeStats* stats_ptr_
private

Definition at line 183 of file nnet-compute-discriminative-parallel.cc.

◆ store_separate_gradients_

bool store_separate_gradients_
private

Definition at line 179 of file nnet-compute-discriminative-parallel.cc.

◆ tmodel_

const TransitionModel& tmodel_
private

Definition at line 177 of file nnet-compute-discriminative-parallel.cc.


The documentation for this class was generated from the following file:

nnet-compute-discriminative-parallel.cc