nnet-compute-discriminative-parallel.cc
Go to the documentation of this file.
1 // nnet2/nnet-compute-discriminative-parallel.cc
2 
3 // Copyright 2012-2013 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include <deque>
21 #include <mutex>
23 #include "hmm/posterior.h"
24 #include "lat/lattice-functions.h"
25 #include "util/kaldi-semaphore.h"
26 #include "util/kaldi-thread.h"
27 
28 namespace kaldi {
29 namespace nnet2 {
30 
// DiscriminativeExamplesRepository: a thread-safe buffer of discriminative
// training examples passed from a producer (reading examples from disk) to
// consumer training threads.
// NOTE(review): this is a collapsed Doxygen view -- the class header line,
// the ProvideExample() declaration, the constructor signature/earlier
// initializers, and (presumably) semaphore members are on lines not shown
// here; confirm against the full source before editing.
34  public:
// Hands one example to the repository; a heap-allocated copy is stored
// (see the push_back of `new DiscriminativeNnetExample(example)` below).
36  void AcceptExample(const DiscriminativeNnetExample &example);
37 
// Called by the producing thread once all examples have been accepted.
41  void ExamplesDone();
42 
48 
// Tail of the constructor's initializer list (earlier initializers are on
// lines not visible in this view).
51  done_(false) { }
52  private:
56  std::mutex examples_mutex_; // mutex we lock to modify examples_.
57 
// Pending examples; elements are owned (heap-allocated) by this class.
58  std::deque<DiscriminativeNnetExample*> examples_;
59  bool done_;  // set true by ExamplesDone() when no more examples will come.
61 };
62 
63 
// AcceptExample: copies `example` onto the heap and appends it to the queue
// under the mutex. NOTE(review): the function-name line and (presumably) a
// semaphore wait that bounds the buffer size are on lines not shown in this
// collapsed view.
65  const DiscriminativeNnetExample &example) {
67  examples_mutex_.lock();
// Heap copy: the queue stores owning raw pointers, freed by the consumer.
68  examples_.push_back(new DiscriminativeNnetExample(example));
69  examples_mutex_.unlock();
71 }
72 
// ExamplesDone: marks the repository finished once all buffered examples
// have been consumed. NOTE(review): the loop body on the line after the
// `for` (not shown in this collapsed view) presumably waits on a semaphore
// buffer_size_ times so that the empty-queue assertion below is valid;
// a final signal after `done_ = true` is likewise not visible here.
74  for (int32 i = 0; i < buffer_size_; i++)
76  examples_mutex_.lock();
// By this point every accepted example must have been taken by a consumer.
77  KALDI_ASSERT(examples_.empty());
78  examples_mutex_.unlock();
79  done_ = true;
81 }
82 
// ProvideExample: called by each training thread; returns the next example
// (ownership transfers to the caller, who must delete it) or NULL when all
// examples are finished. NOTE(review): the function signature line and a
// preceding wait on full_semaphore_ are on lines not shown in this view.
86  if (done_) {
87  KALDI_ASSERT(examples_.empty());
// Re-signal so every remaining consumer thread also wakes up, sees done_,
// and returns NULL instead of blocking forever.
88  full_semaphore_.Signal(); // Increment the semaphore so
89  // the call by the next thread will not block.
90  return NULL; // no examples to return-- all finished.
91  } else {
92  examples_mutex_.lock();
93  KALDI_ASSERT(!examples_.empty());
// Pop from the front: examples are handed out in FIFO order.
94  DiscriminativeNnetExample *ans = examples_.front();
95  examples_.pop_front();
96  examples_mutex_.unlock();
98  return ans;
99  }
100 }
101 
102 
// DiscTrainParallelClass: the per-thread worker run by MultiThreader.
// Each copy pulls examples from the shared repository and calls
// NnetDiscriminativeUpdate on them. NOTE(review): this is a collapsed
// Doxygen view -- the class header line, parts of the constructor
// signatures, the destructor signature, and most private members are on
// lines not shown here.
104  public:
105  // This constructor is only called for a temporary object
106  // that we pass to the RunMultiThreaded function.
108  const TransitionModel &tmodel,
110  bool store_separate_gradients,
112  Nnet *nnet_to_update,
113  NnetDiscriminativeStats *stats):
114  am_nnet_(am_nnet), tmodel_(tmodel), opts_(opts),
115  store_separate_gradients_(store_separate_gradients),
116  repository_(repository),
117  nnet_to_update_(nnet_to_update),
// nnet_to_update_orig_ remembers the caller's gradient object so the
// destructor can tell whether nnet_to_update_ was later replaced by a
// private per-thread copy.
118  nnet_to_update_orig_(nnet_to_update),
119  stats_ptr_(stats) { }
120 
121  // The following constructor is called multiple times within
122  // the RunMultiThreaded template function.
124  MultiThreadable(other),
125  am_nnet_(other.am_nnet_), tmodel_(other.tmodel_), opts_(other.opts_),
126  store_separate_gradients_(other.store_separate_gradients_),
127  repository_(other.repository_), nnet_to_update_(other.nnet_to_update_),
128  nnet_to_update_orig_(other.nnet_to_update_orig_),
129  stats_ptr_(other.stats_ptr_) {
130  if (store_separate_gradients_) {
131  // To ensure correctness, we work on separate copies of the gradient
132  // object, which we'll sum at the end. This is used for exact gradient
133  // computation.
134  if (other.nnet_to_update_ != NULL) {
135  nnet_to_update_ = new Nnet(*(other.nnet_to_update_));
136  // our "nnet_to_update_" variable is a copy of the neural network
137  // we are to update (presumably a gradient). If we don't set these
138  // to zero we would end up adding multiple copies of the any initial
139  // gradient that "nnet_to_update_" contained when we initialize
140  // the first instance of the class.
141  nnet_to_update_->SetZero(true);
142  } else { // support case where we don't really need a gradient.
143  nnet_to_update_ = NULL;
144  }
145  }
146  }
147  // This does the main function of the class.
148  void operator () () {
149  DiscriminativeNnetExample *example;
// Loop until the repository signals completion by returning NULL.
150  while ((example = repository_->ProvideExample()) != NULL) {
151  // This is a function call to a function defined in
152  // nnet-compute-discriminative.h
153  NnetDiscriminativeUpdate(am_nnet_, tmodel_, opts_,
154  *example, nnet_to_update_, &stats_);
// ProvideExample() transferred ownership to us, so we free the example.
155  delete example;
156 
157  if (GetVerboseLevel() > 3) {
158  KALDI_VLOG(3) << "Printing local stats for thread " << thread_id_;
159  stats_.Print(opts_.criterion);
160  }
161  }
162  }
163 
// Destructor body (the `~DiscTrainParallelClass()` line itself is not
// visible in this collapsed view): merges per-thread results back into the
// shared objects.
165  if (nnet_to_update_orig_ != nnet_to_update_) {
166  // This branch is only taken if this instance of the class is
167  // one of the multiple instances allocated inside the RunMultiThreaded
168  // template function, *and* store_separate_gradients_ has been set to true.
169  // In the typical hogwild case, we don't do this.
170  nnet_to_update_orig_->AddNnet(1.0, *nnet_to_update_);
171  delete nnet_to_update_;
172  }
// Accumulate this thread's local stats into the caller-owned stats object.
173  stats_ptr_->Add(stats_);
174  }
175  private:
176  const AmNnet &am_nnet_;
// (remaining private members -- tmodel_, opts_, repository_,
// nnet_to_update_ pointers, stats_ -- are on lines not shown here.)
185 };
186 
187 
188 
// NnetDiscriminativeUpdateParallel: top-level driver. Reads examples from
// example_reader, feeds them through a DiscriminativeExamplesRepository to
// num_threads worker threads, and accumulates gradients/stats.
// NOTE(review): the function-name line, the `opts` and `example_reader`
// parameters, and the local `repository` declaration are on lines not shown
// in this collapsed view (the full signature appears in the trailing
// Doxygen index text).
190  const AmNnet &am_nnet,
191  const TransitionModel &tmodel,
193  int32 num_threads,
195  Nnet *nnet_to_update,
196  NnetDiscriminativeStats *stats) {
197 
199 
// If the caller passed a gradient object distinct from the model itself,
// each thread must work on its own copy (exact-gradient mode); otherwise
// we are doing "hogwild"-style updates directly on the model.
200  const bool store_separate_gradients = (nnet_to_update != &(am_nnet.GetNnet()));
201 
202  DiscTrainParallelClass c(am_nnet, tmodel, opts,
203  store_separate_gradients,
204  &repository, nnet_to_update, stats);
205 
206  {
207  // The initialization of the following class spawns the threads that
208  // process the examples. They get re-joined in its destructor.
209  MultiThreader<DiscTrainParallelClass> m(num_threads, c);
210 
// Producer loop: this (main) thread feeds examples to the repository
// while the worker threads consume them.
211  for (; !example_reader->Done(); example_reader->Next()) {
212  repository.AcceptExample(example_reader->Value());
213  }
214  repository.ExamplesDone();
215  }
// All threads have been joined here (end of MultiThreader's scope), so
// stats now contains the merged totals from every worker.
216  stats->Print(opts.criterion);
217 }
218 
219 
220 
221 } // namespace nnet2
222 } // namespace kaldi
KALDI_DISALLOW_COPY_AND_ASSIGN(DiscriminativeExamplesRepository)
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for mispronunciation detection tasks.
Definition: chain.dox:20
void AcceptExample(const DiscriminativeNnetExample &example)
The following function is called by the code that reads in the examples.
DiscTrainParallelClass(const DiscTrainParallelClass &other)
This struct stores neural net training examples to be used in multi-threaded training.
int32 GetVerboseLevel()
Get verbosity level, usually set via the command line '--verbose=' switch.
Definition: kaldi-error.h:60
void Signal()
increase the counter
kaldi::int32 int32
void ExamplesDone()
The following function is called by the code that reads in the examples, when we're done reading examples.
A templated class for reading objects sequentially from an archive or script file; see The Table concept.
Definition: kaldi-table.h:287
DiscriminativeNnetExample * ProvideExample()
This function is called by the code that does the training.
This struct is used to store the information we need for discriminative training (MMI or MPE)...
Definition: nnet-example.h:136
DiscTrainParallelClass(const AmNnet &am_nnet, const TransitionModel &tmodel, const NnetDiscriminativeUpdateOptions &opts, bool store_separate_gradients, DiscriminativeExamplesRepository *repository, Nnet *nnet_to_update, NnetDiscriminativeStats *stats)
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
void NnetDiscriminativeUpdateParallel(const AmNnet &am_nnet, const TransitionModel &tmodel, const NnetDiscriminativeUpdateOptions &opts, int32 num_threads, SequentialDiscriminativeNnetExampleReader *example_reader, Nnet *nnet_to_update, NnetDiscriminativeStats *stats)
void Wait()
decrease the counter
const Nnet & GetNnet() const
Definition: am-nnet.h:61
void NnetDiscriminativeUpdate(const AmNnet &am_nnet, const TransitionModel &tmodel, const NnetDiscriminativeUpdateOptions &opts, const DiscriminativeNnetExample &eg, Nnet *nnet_to_update, NnetDiscriminativeStats *stats)
Does the neural net computation, lattice forward-backward, and backprop, for either the MMI, MPFE, or SMBR objective functions.