110 bool store_separate_gradients,
112 Nnet *nnet_to_update,
114 am_nnet_(am_nnet), tmodel_(tmodel), opts_(opts),
115 store_separate_gradients_(store_separate_gradients),
116 repository_(repository),
117 nnet_to_update_(nnet_to_update),
118 nnet_to_update_orig_(nnet_to_update),
119 stats_ptr_(stats) { }
125 am_nnet_(other.am_nnet_), tmodel_(other.tmodel_), opts_(other.opts_),
126 store_separate_gradients_(other.store_separate_gradients_),
127 repository_(other.repository_), nnet_to_update_(other.nnet_to_update_),
128 nnet_to_update_orig_(other.nnet_to_update_orig_),
129 stats_ptr_(other.stats_ptr_) {
130 if (store_separate_gradients_) {
141 nnet_to_update_->SetZero(
true);
143 nnet_to_update_ = NULL;
148 void operator () () {
150 while ((example = repository_->ProvideExample()) != NULL) {
154 *example, nnet_to_update_, &stats_);
158 KALDI_VLOG(3) <<
"Printing local stats for thread " << thread_id_;
159 stats_.Print(opts_.criterion);
165 if (nnet_to_update_orig_ != nnet_to_update_) {
170 nnet_to_update_orig_->AddNnet(1.0, *nnet_to_update_);
171 delete nnet_to_update_;
173 stats_ptr_->Add(stats_);
195 Nnet *nnet_to_update,
200 const bool store_separate_gradients = (nnet_to_update != &(am_nnet.
GetNnet()));
203 store_separate_gradients,
204 &repository, nnet_to_update, stats);
211 for (; !example_reader->
Done(); example_reader->
Next()) {
KALDI_DISALLOW_COPY_AND_ASSIGN(DiscriminativeExamplesRepository)
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
void AcceptExample(const DiscriminativeNnetExample &example)
The following function is called by the code that reads in the examples.
DiscTrainParallelClass(const DiscTrainParallelClass &other)
This struct stores neural net training examples to be used in multi-threaded training.
int32 GetVerboseLevel()
Get verbosity level, usually set via command line '--verbose=' switch.
Semaphore empty_semaphore_
void Signal()
increase the counter
Nnet * nnet_to_update_orig_
DiscriminativeExamplesRepository * repository_
bool store_separate_gradients_
void ExamplesDone()
The following function is called by the code that reads in the examples, when we're done reading exam...
NnetDiscriminativeStats stats_
Semaphore full_semaphore_
const NnetDiscriminativeUpdateOptions & opts_
~DiscTrainParallelClass()
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
DiscriminativeExamplesRepository()
void Print(std::string criterion)
DiscriminativeNnetExample * ProvideExample()
This function is called by the code that does the training.
const TransitionModel & tmodel_
This struct is used to store the information we need for discriminative training (MMI or MPE)...
DiscTrainParallelClass(const AmNnet &am_nnet, const TransitionModel &tmodel, const NnetDiscriminativeUpdateOptions &opts, bool store_separate_gradients, DiscriminativeExamplesRepository *repository, Nnet *nnet_to_update, NnetDiscriminativeStats *stats)
#define KALDI_ASSERT(cond)
std::mutex examples_mutex_
void NnetDiscriminativeUpdateParallel(const AmNnet &am_nnet, const TransitionModel &tmodel, const NnetDiscriminativeUpdateOptions &opts, int32 num_threads, SequentialDiscriminativeNnetExampleReader *example_reader, Nnet *nnet_to_update, NnetDiscriminativeStats *stats)
NnetDiscriminativeStats * stats_ptr_
std::deque< DiscriminativeNnetExample * > examples_
void Wait()
decrease the counter
const Nnet & GetNnet() const
void NnetDiscriminativeUpdate(const AmNnet &am_nnet, const TransitionModel &tmodel, const NnetDiscriminativeUpdateOptions &opts, const DiscriminativeNnetExample &eg, Nnet *nnet_to_update, NnetDiscriminativeStats *stats)
Does the neural net computation, lattice forward-backward, and backprop, for either the MMI...