// nnet2/train-nnet.cc

// Copyright 2012 Johns Hopkins University (author: Daniel Povey)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
20 #include "nnet2/train-nnet.h"
21 #include "util/kaldi-thread.h"
22 
23 namespace kaldi {
24 namespace nnet2 {
25 
26 
28  public:
30  Nnet *nnet,
32  minibatch_size_(minibatch_size), nnet_(nnet), reader_(reader),
33  finished_(false) {
34  // When this class is created, it spawns a thread which calls ReadExamples()
35  // in the background. Below, Run is the static class-member function.
36  thread_ = std::thread(Run, this);
37  // the following call is a signal that no-one is currently using the examples_ and
38  // formatted_examples_ class members.
40  }
42  if (!thread_.joinable())
43  KALDI_ERR << "No thread to join.";
44  thread_.join();
45  }
46 
47  // This will be called in a background thread. It's responsible for
48  // reading and formatting the examples.
49  void ReadExamples() {
51  int32 minibatch_size = minibatch_size_;
52 
53 
54  // Loop over minibatches...
55  while (true) {
56  // When the following call succeeds we interpret it as a signal that
57  // we are free to write to the class-member variables examples_ and formatted_examples_.
59 
60  examples_.clear();
61  examples_.reserve(minibatch_size);
62  // Read the examples.
63  for (; examples_.size() < minibatch_size && !reader_->Done(); reader_->Next())
64  examples_.push_back(reader_->Value());
65 
66  // Format the examples as a single matrix. The reason we do this here is
67  // that it's a somewhat CPU-intensive operation (involves decompressing
68  // the matrix), so we do it in a separate thread from the one that's
69  // controlling the GPU (assuming we're using a GPU), so we can get better
70  // GPU utilization. If we have no GPU this doesn't hurt us.
71  if (examples_.empty()) {
73  total_weight_ = 0.0;
74  } else {
77  }
78 
79  bool finished = examples_.empty();
80 
81  // The following call alerts the main program thread (that calls
82  // GetNextMinibatch() that it can how use the contents of
83  // examples_ and formatted_examples_.
85 
86  // If we just read an empty minibatch (because no more examples),
87  // then return.
88  if (finished)
89  return;
90  }
91  }
92 
93  // this wrapper can be passed to pthread_create.
94  static void* Run(void *ptr_in) {
96  reinterpret_cast<NnetExampleBackgroundReader*>(ptr_in);
97  ptr->ReadExamples();
98  return NULL;
99  }
100 
101  // This call makes available the next minibatch of input. It returns
102  // true if it got some, and false if there was no more available.
103  // It is an error if you call this function after it has returned false.
104  bool GetNextMinibatch(std::vector<NnetExample> *examples,
105  Matrix<BaseFloat> *formatted_examples,
106  double *total_weight) {
108  // wait until examples_ and formatted_examples_ have been created by
109  // the background thread.
111  // the calls to swap and Swap are lightweight.
112  examples_.swap(*examples);
113  formatted_examples_.Swap(formatted_examples);
114  *total_weight = total_weight_;
115 
116  // signal the background thread that it is now free to write
117  // again to examples_ and formatted_examples_.
119 
120  if (examples->empty()) {
121  finished_ = true;
122  return false;
123  } else {
124  return true;
125  }
126  }
127 
128  private:
132  std::thread thread_;
133 
134  std::vector<NnetExample> examples_;
136  double total_weight_; // total weight, from TotalNnetTrainingWeight(examples_).
137  // better to compute this in the background thread.
138 
141 
142  bool finished_;
143 };
144 
145 
146 
148  Nnet *nnet,
150  double *tot_weight_ptr,
151  double *tot_logprob_ptr) {
152  int64 num_egs_processed = 0;
153  double tot_weight = 0.0, tot_logprob = 0.0;
154  NnetExampleBackgroundReader background_reader(config.minibatch_size,
155  nnet, reader);
157  while (true) {
158  // Iterate over phases. A phase of training is just a certain number of
159  // minibatches, and its only significance is that it's the periodicity with
160  // which we print diagnostics.
161  double tot_weight_this_phase = 0.0, tot_logprob_this_phase = 0.0;
162 
163  int32 i;
164  for (i = 0; i < config.minibatches_per_phase; i++) {
165  std::vector<NnetExample> examples;
166  Matrix<BaseFloat> examples_formatted;
167  double minibatch_total_weight; // this will normally equal minibatch size.
168  if (!background_reader.GetNextMinibatch(&examples, &examples_formatted,
169  &minibatch_total_weight))
170  break;
171  tot_logprob_this_phase += DoBackprop(*nnet, examples, &examples_formatted,
172  nnet, NULL);
173  tot_weight_this_phase += minibatch_total_weight;
174  num_egs_processed += examples.size();
175  }
176  if (i != 0) {
177  KALDI_LOG << "Training objective function (this phase) is "
178  << (tot_logprob_this_phase / tot_weight_this_phase) << " over "
179  << tot_weight_this_phase << " frames.";
180  }
181  tot_weight += tot_weight_this_phase;
182  tot_logprob += tot_logprob_this_phase;
183  if (i != config.minibatches_per_phase) {
184  // did not get all the minibatches we wanted because no more input.
185  // this is true if and only if we did "break" in the loop over i above.
186  break;
187  }
188  }
189  if (tot_weight == 0.0) {
190  KALDI_WARN << "No data seen.";
191  } else {
192  KALDI_LOG << "Did backprop on " << tot_weight
193  << " examples, average log-prob per frame is "
194  << (tot_logprob / tot_weight);
195  KALDI_LOG << "[this line is to be parsed by a script:] log-prob-per-frame="
196  << (tot_logprob / tot_weight);
197  }
198  if (tot_weight_ptr) *tot_weight_ptr = tot_weight;
199  if (tot_logprob_ptr) *tot_logprob_ptr = tot_logprob;
200  return num_egs_processed;
201 }
202 
203 
204 
205 } // namespace nnet2
206 } // namespace kaldi
bool GetNextMinibatch(std::vector< NnetExample > *examples, Matrix< BaseFloat > *formatted_examples, double *total_weight)
Definition: train-nnet.cc:104
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
std::vector< NnetExample > examples_
Definition: train-nnet.cc:134
static void * Run(void *ptr_in)
Definition: train-nnet.cc:94
void Signal()
increase the counter
kaldi::int32 int32
void Swap(Matrix< Real > *other)
Swaps the contents of *this and *other. Shallow swap.
double DoBackprop(const Nnet &nnet, const std::vector< NnetExample > &examples, Nnet *nnet_to_update, double *tot_accuracy)
This function computes the objective function and either updates the model or adds to parameter gradi...
Definition: nnet-update.cc:265
void FormatNnetInput(const Nnet &nnet, const std::vector< NnetExample > &data, Matrix< BaseFloat > *input_mat)
Takes the input to the nnet for a minibatch of examples, and formats as a single matrix.
Definition: nnet-update.cc:207
int64 TrainNnetSimple(const NnetSimpleTrainerConfig &config, Nnet *nnet, SequentialNnetExampleReader *reader, double *tot_weight_ptr, double *tot_logprob_ptr)
Train on all the examples it can read from the reader.
Definition: train-nnet.cc:147
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
SequentialNnetExampleReader * reader_
Definition: train-nnet.cc:131
BaseFloat TotalNnetTrainingWeight(const std::vector< NnetExample > &egs)
Returns the total weight summed over all the examples...
Definition: nnet-update.cc:248
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void Resize(const MatrixIndexT r, const MatrixIndexT c, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Sets matrix to a specified size (zero is OK as long as both r and c are zero).
#define KALDI_LOG
Definition: kaldi-error.h:153
void Wait()
decrease the counter
NnetExampleBackgroundReader(int32 minibatch_size, Nnet *nnet, SequentialNnetExampleReader *reader)
Definition: train-nnet.cc:29