DiscTrainParallelClass Class Reference

Public Member Functions

 DiscTrainParallelClass (const AmNnet &am_nnet, const TransitionModel &tmodel, const NnetDiscriminativeUpdateOptions &opts, bool store_separate_gradients, DiscriminativeExamplesRepository *repository, Nnet *nnet_to_update, NnetDiscriminativeStats *stats)
 DiscTrainParallelClass (const DiscTrainParallelClass &other)
void operator() ()
 ~DiscTrainParallelClass ()
- Public Member Functions inherited from MultiThreadable
virtual ~MultiThreadable ()

Private Attributes

const AmNnet & am_nnet_
const TransitionModel & tmodel_
bool store_separate_gradients_
NnetDiscriminativeStats stats_

Additional Inherited Members

- Public Attributes inherited from MultiThreadable
int32 thread_id_
int32 num_threads_

Detailed Description

Definition at line 103 of file

Constructor & Destructor Documentation

DiscTrainParallelClass ( const AmNnet & am_nnet,
const TransitionModel & tmodel,
const NnetDiscriminativeUpdateOptions & opts,
bool store_separate_gradients,
DiscriminativeExamplesRepository * repository,
Nnet * nnet_to_update,
NnetDiscriminativeStats * stats 
)

DiscTrainParallelClass ( const DiscTrainParallelClass & other )

Definition at line 123 of file

References DiscTrainParallelClass::nnet_to_update_, Nnet::SetZero(), and DiscTrainParallelClass::store_separate_gradients_.

```cpp
  : MultiThreadable(other),
    am_nnet_(other.am_nnet_), tmodel_(other.tmodel_), opts_(other.opts_),
    store_separate_gradients_(other.store_separate_gradients_),
    repository_(other.repository_), nnet_to_update_(other.nnet_to_update_),
    nnet_to_update_orig_(other.nnet_to_update_orig_),
    stats_ptr_(other.stats_ptr_) {
  if (store_separate_gradients_) {
    // To ensure correctness, we work on separate copies of the gradient
    // object, which we'll sum at the end. This is used for exact gradient
    // computation.
    if (other.nnet_to_update_ != NULL) {
      nnet_to_update_ = new Nnet(*(other.nnet_to_update_));
      // our "nnet_to_update_" variable is a copy of the neural network
      // we are to update (presumably a gradient). If we don't set these
      // to zero we would end up adding multiple copies of any initial
      // gradient that "nnet_to_update_" contained when we initialize
      // the first instance of the class.
      nnet_to_update_->SetZero(true);
    } else {  // support case where we don't really need a gradient.
      nnet_to_update_ = NULL;
    }
  }
}
```
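The comment about zeroing can be made concrete with a short worked equation. Let g_0 be whatever the gradient object held when the first instance was constructed, N the number of per-thread copies, and Delta_i the contribution accumulated by copy i. Assume, as the destructor's references to Nnet::AddNnet() and nnet_to_update_orig_ suggest, that every copy is added back into the original with scale 1:

```latex
\begin{aligned}
\text{without zeroing:}\quad g_{\text{final}} &= g_0 + \sum_{i=1}^{N} \left(g_0 + \Delta_i\right) = (N+1)\,g_0 + \sum_{i=1}^{N} \Delta_i \\
\text{with SetZero(true):}\quad g_{\text{final}} &= g_0 + \sum_{i=1}^{N} \left(0 + \Delta_i\right) = g_0 + \sum_{i=1}^{N} \Delta_i
\end{aligned}
```

Only the second form counts g_0 once, which is the behaviour the comment in the listing asks for.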
~DiscTrainParallelClass ( )

Definition at line 164 of file

References NnetDiscriminativeStats::Add(), Nnet::AddNnet(), DiscTrainParallelClass::nnet_to_update_, DiscTrainParallelClass::nnet_to_update_orig_, DiscTrainParallelClass::stats_, and DiscTrainParallelClass::stats_ptr_.

```cpp
{
  if (nnet_to_update_orig_ != nnet_to_update_) {
    // This branch is only taken if this instance of the class is
    // one of the multiple instances allocated inside the RunMultiThreaded
    // template function, *and* store_separate_gradients_ has been set to true.
    // In the typical hogwild case, we don't do this.
    nnet_to_update_orig_->AddNnet(1.0, *nnet_to_update_);
    delete nnet_to_update_;
  }
  stats_ptr_->Add(stats_);
}
```
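The copy constructor and destructor together implement an accumulate-then-merge pattern: each copy made for a thread owns a zeroed private gradient, and its destructor adds that gradient back into the original and folds its local statistics into the object behind stats_ptr_. The sketch below reproduces that ownership pattern in standalone C++ (C++11 or later) under stated assumptions. Gradient, Worker, RunCopiesInThreads and Demo are hypothetical names. RunCopiesInThreads is only a stand-in for the RunMultiThreaded template mentioned in the comments; its behaviour here (copies destroyed on the calling thread after every thread has joined, so the merges never race) is an assumption of the sketch, not a description of Kaldi's function. Only the store_separate_gradients_ == true path is modelled; the hogwild path, where all threads write one shared object without locks, is left out.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <thread>
#include <utility>
#include <vector>

// Hypothetical stand-in for the gradient object (an Nnet in the class above).
struct Gradient {
  std::vector<double> params;
  explicit Gradient(std::size_t dim) : params(dim, 0.0) {}
  void SetZero() {
    for (double &p : params) p = 0.0;
  }
  void Add(double alpha, const Gradient &other) {  // this += alpha * other
    assert(other.params.size() == params.size());
    for (std::size_t i = 0; i < params.size(); ++i) params[i] += alpha * other.params[i];
  }
};

// Hypothetical worker mirroring the ownership pattern of DiscTrainParallelClass
// (separate-gradients case only).
class Worker {
 public:
  int thread_id_ = 0;  // assigned by the driver, like MultiThreadable::thread_id_

  // Prototype: grad_ and grad_orig_ point at the same object.
  Worker(Gradient *grad, long *total_items)
      : grad_(grad), grad_orig_(grad), total_ptr_(total_items) {
    assert(total_items != nullptr);
  }

  // Per-thread copy: allocate a private gradient and zero it, so the
  // original contents are not added back once per copy.
  Worker(const Worker &other)
      : thread_id_(other.thread_id_), grad_(other.grad_),
        grad_orig_(other.grad_orig_), total_ptr_(other.total_ptr_) {
    if (other.grad_ != nullptr) {
      grad_ = new Gradient(*other.grad_);
      grad_->SetZero();
    }
  }
  Worker &operator=(const Worker &) = delete;

  ~Worker() {
    if (grad_orig_ != grad_) {  // true only for copies owning a private gradient
      grad_orig_->Add(1.0, *grad_);
      delete grad_;
    }
    *total_ptr_ += local_items_;  // merge this instance's statistics
  }

  // Stand-in for the training loop: thread i "processes" i + 1 items,
  // each adding 1.0 to every parameter of its own gradient.
  void operator()() {
    for (int k = 0; k <= thread_id_; ++k) {
      if (grad_ != nullptr)
        for (double &p : grad_->params) p += 1.0;
      ++local_items_;
    }
  }

 private:
  Gradient *grad_;
  Gradient *grad_orig_;
  long *total_ptr_;
  long local_items_ = 0;  // copies start with empty statistics
};

// Hypothetical driver: one copy per thread, run, join, then the copies are
// destroyed on this thread after all workers finished, so merges never overlap.
template <class C>
void RunCopiesInThreads(const C &prototype, int num_threads) {
  assert(num_threads > 0);
  std::vector<C> copies(num_threads, prototype);  // copy constructor per element
  for (int i = 0; i < num_threads; ++i) copies[i].thread_id_ = i;
  std::vector<std::thread> threads;
  for (C &c : copies) threads.emplace_back(std::ref(c));
  for (std::thread &t : threads) t.join();
}  // threads are destroyed first, then the copies (which merge)

// Returns {merged first parameter, merged item count}.
std::pair<double, long> Demo(int num_threads, double initial_value) {
  Gradient grad(3);
  grad.params.assign(3, initial_value);  // pre-existing gradient contents
  long total_items = 0;
  {
    Worker prototype(&grad, &total_items);
    RunCopiesInThreads(prototype, num_threads);
  }  // the prototype's destructor adds only its own (zero) item count
  return {grad.params[0], total_items};
}
```

Calling Demo(4, 10.0) starts from a gradient filled with 10, lets the four copies contribute 1 + 2 + 3 + 4, and returns {20.0, 10}: the initial contents are counted once, as in the second line of the zeroing argument for the copy constructor.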

Member Function Documentation

void operator() ( )

Implements MultiThreadable.

Definition at line 148 of file

References DiscTrainParallelClass::am_nnet_, NnetDiscriminativeUpdateOptions::criterion, kaldi::GetVerboseLevel(), KALDI_VLOG, DiscTrainParallelClass::nnet_to_update_, kaldi::nnet2::NnetDiscriminativeUpdate(), DiscTrainParallelClass::opts_, NnetDiscriminativeStats::Print(), DiscriminativeExamplesRepository::ProvideExample(), DiscTrainParallelClass::repository_, DiscTrainParallelClass::stats_, MultiThreadable::thread_id_, and DiscTrainParallelClass::tmodel_.

```cpp
{
  DiscriminativeNnetExample *example;
  while ((example = repository_->ProvideExample()) != NULL) {
    // This is a function call to a function defined in
    // nnet-compute-discriminative.h
    NnetDiscriminativeUpdate(am_nnet_, tmodel_, opts_,
                             *example, nnet_to_update_, &stats_);
    delete example;

    if (GetVerboseLevel() > 3) {
      KALDI_VLOG(3) << "Printing local stats for thread " << thread_id_;
      stats_.Print(opts_.criterion);
    }
  }
}
```
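The loop in operator() relies on one contract: ProvideExample() blocks until an example is available, hands ownership of it to the caller (hence the delete), and returns NULL once no more examples will arrive. The sketch below shows one way such a repository could satisfy that contract, using std::mutex and std::condition_variable. Only the name ProvideExample comes from this page; Example, ExampleRepository, AcceptExample, ExamplesDone, WorkerLoop and SumWithWorkers are hypothetical, and the queue is unbounded, which may differ from how DiscriminativeExamplesRepository actually buffers examples.

```cpp
#include <cassert>
#include <condition_variable>
#include <deque>
#include <mutex>
#include <thread>
#include <vector>

// Hypothetical stand-in for DiscriminativeNnetExample.
struct Example {
  int value;
};

// Hypothetical unbounded repository shared by one producer and many consumers.
class ExampleRepository {
 public:
  // Producer: hand over ownership of one example.
  void AcceptExample(Example *eg) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      queue_.push_back(eg);
    }
    cv_.notify_one();
  }
  // Producer: signal that no more examples will arrive.
  void ExamplesDone() {
    {
      std::lock_guard<std::mutex> lock(mu_);
      done_ = true;
    }
    cv_.notify_all();
  }
  // Consumer: blocks until an example is available; returns nullptr once
  // ExamplesDone() was called and the queue is drained. Caller must delete.
  Example *ProvideExample() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return !queue_.empty() || done_; });
    if (queue_.empty()) return nullptr;
    Example *eg = queue_.front();
    queue_.pop_front();
    return eg;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<Example *> queue_;
  bool done_ = false;
};

// Same shape as operator(): pull until nullptr, process, delete.
void WorkerLoop(ExampleRepository *repository, long *local_sum) {
  Example *example;
  while ((example = repository->ProvideExample()) != nullptr) {
    *local_sum += example->value;  // stand-in for NnetDiscriminativeUpdate()
    delete example;
  }
}

// Feeds the values 1..num_examples to num_workers consumers and sums their results.
long SumWithWorkers(int num_workers, int num_examples) {
  assert(num_workers > 0);
  ExampleRepository repository;
  std::vector<long> sums(num_workers, 0L);  // one slot per worker, nothing shared
  std::vector<std::thread> threads;
  for (int i = 0; i < num_workers; ++i)
    threads.emplace_back(WorkerLoop, &repository, &sums[i]);
  for (int v = 1; v <= num_examples; ++v) repository.AcceptExample(new Example{v});
  repository.ExamplesDone();
  for (std::thread &t : threads) t.join();
  long total = 0;
  for (long s : sums) total += s;
  return total;
}
```

SumWithWorkers(4, 100) returns 5050 however the examples happen to be split among the four workers, because each example is handed to exactly one caller of ProvideExample(). Keeping per-worker results in separate slots mirrors how each DiscTrainParallelClass copy accumulates into its own stats_ rather than a shared object.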

Member Data Documentation

const AmNnet& am_nnet_
Nnet* nnet_to_update_orig_
bool store_separate_gradients_
const TransitionModel& tmodel_

The documentation for this class was generated from the following file: