NnetExampleBackgroundReader Class Reference

Public Member Functions

 NnetExampleBackgroundReader (int32 minibatch_size, Nnet *nnet, SequentialNnetExampleReader *reader)
 
 ~NnetExampleBackgroundReader ()
 
void ReadExamples ()
 
bool GetNextMinibatch (std::vector< NnetExample > *examples, Matrix< BaseFloat > *formatted_examples, double *total_weight)
 

Static Public Member Functions

static void * Run (void *ptr_in)
 

Private Attributes

int32 minibatch_size_
 
Nnet * nnet_
 
SequentialNnetExampleReader * reader_
 
std::thread thread_
 
std::vector< NnetExample > examples_
 
Matrix< BaseFloat > formatted_examples_
 
double total_weight_
 
Semaphore producer_semaphore_
 
Semaphore consumer_semaphore_
 
bool finished_
 

Detailed Description

Definition at line 27 of file train-nnet.cc.

Constructor & Destructor Documentation

NnetExampleBackgroundReader ( int32  minibatch_size,
Nnet * nnet,
SequentialNnetExampleReader * reader 
)
inline

Definition at line 29 of file train-nnet.cc.

References NnetExampleBackgroundReader::consumer_semaphore_, NnetExampleBackgroundReader::Run(), Semaphore::Signal(), and NnetExampleBackgroundReader::thread_.

    : minibatch_size_(minibatch_size), nnet_(nnet), reader_(reader),
      finished_(false) {
  // When this class is created, it spawns a thread which calls ReadExamples()
  // in the background. Below, Run is the static class-member function.
  thread_ = std::thread(Run, this);
  // the following call is a signal that no-one is currently using the examples_ and
  // formatted_examples_ class members.
  consumer_semaphore_.Signal();
}
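The constructor's closing comment and its References list show that, after spawning the thread, it calls Signal() on consumer_semaphore_. That gives the background thread one token meaning "the shared buffers are free". Kaldi's Semaphore is characterized on this page only by Signal() increasing a counter and Wait() decreasing it. The sketch below is a hypothetical stand-in with those two operations, built on std::mutex and std::condition_variable; it is not Kaldi's implementation, and the TryWait() helper is included only so the counting behavior can be checked without blocking.

```cpp
#include <cassert>
#include <condition_variable>
#include <mutex>

// Hypothetical stand-in for a counting semaphore such as kaldi::Semaphore.
class Semaphore {
 public:
  explicit Semaphore(int count = 0) : count_(count) {}

  // Increase the counter and wake one waiting thread, if any.
  void Signal() {
    std::lock_guard<std::mutex> lock(mutex_);
    ++count_;
    cond_.notify_one();
  }

  // Block until the counter is positive, then decrease it.
  void Wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    cond_.wait(lock, [this] { return count_ > 0; });
    --count_;
  }

  // Non-blocking variant: decrease the counter only if it is positive.
  bool TryWait() {
    std::lock_guard<std::mutex> lock(mutex_);
    if (count_ == 0) return false;
    --count_;
    return true;
  }

 private:
  std::mutex mutex_;
  std::condition_variable cond_;
  int count_;
};
```

A design point visible in the listing: thread_ is assigned in the constructor body rather than in the member-initializer list. The Private Attributes list thread_ ahead of the two semaphores, which suggests that is their declaration order; starting the thread from the initializer list could then let ReadExamples() call Wait() on a semaphore that has not been constructed yet. By the time the constructor body runs, every member exists.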

~NnetExampleBackgroundReader ( )
inline

Definition at line 41 of file train-nnet.cc.

References KALDI_ERR, and NnetExampleBackgroundReader::thread_.

{
  if (!thread_.joinable())
    KALDI_ERR << "No thread to join.";
  thread_.join();
}
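The destructor checks joinable() before calling join() because std::thread::join() throws std::system_error on a thread that is not joinable, while destroying a std::thread that is still joinable calls std::terminate(). The object therefore has to join its worker before it goes away. The hypothetical BackgroundWorker below is a stripped-down illustration of that start-in-constructor, join-in-destructor lifetime; unlike the Kaldi class, it silently skips the join instead of reporting an error.

```cpp
#include <atomic>
#include <cassert>
#include <thread>

// Hypothetical class that owns a thread for its whole lifetime, like
// NnetExampleBackgroundReader: started in the constructor, joined in the
// destructor.
class BackgroundWorker {
 public:
  // Starting the thread in the initializer list is safe here only because
  // Work() touches no other members of this object.
  explicit BackgroundWorker(std::atomic<int> *counter)
      : thread_(&BackgroundWorker::Work, counter) {}

  ~BackgroundWorker() {
    // join() would throw std::system_error on a non-joinable thread, so
    // check first. Destroying a joinable std::thread instead would call
    // std::terminate().
    if (thread_.joinable()) thread_.join();
  }

 private:
  static void Work(std::atomic<int> *counter) { counter->fetch_add(1); }

  std::thread thread_;
};
```

A related caveat, offered as an observation rather than something this page states: destructors are implicitly noexcept in C++11 and later, so if KALDI_ERR throws (as it does in recent Kaldi versions), reaching the error branch in ~NnetExampleBackgroundReader() would most likely end in std::terminate() rather than in a catchable exception.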

Member Function Documentation

bool GetNextMinibatch ( std::vector< NnetExample > *  examples,
Matrix< BaseFloat > *  formatted_examples,
double *  total_weight 
)
inline

Definition at line 104 of file train-nnet.cc.

References NnetExampleBackgroundReader::consumer_semaphore_, NnetExampleBackgroundReader::examples_, NnetExampleBackgroundReader::finished_, NnetExampleBackgroundReader::formatted_examples_, KALDI_ASSERT, NnetExampleBackgroundReader::producer_semaphore_, Semaphore::Signal(), Matrix< Real >::Swap(), NnetExampleBackgroundReader::total_weight_, and Semaphore::Wait().

Referenced by kaldi::nnet2::TrainNnetSimple().

{
  KALDI_ASSERT(!finished_);
  // wait until examples_ and formatted_examples_ have been created by
  // the background thread.
  producer_semaphore_.Wait();
  // the calls to swap and Swap are lightweight.
  examples_.swap(*examples);
  formatted_examples_.Swap(formatted_examples);
  *total_weight = total_weight_;

  // signal the background thread that it is now free to write
  // again to examples_ and formatted_examples_.
  consumer_semaphore_.Signal();

  if (examples->empty()) {
    finished_ = true;
    return false;
  } else {
    return true;
  }
}
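GetNextMinibatch() hands a finished minibatch to the caller by swapping rather than copying: std::vector::swap exchanges the two vectors' internal buffers in constant time, and Kaldi's Matrix::Swap() is described as a shallow swap. The caller's previous buffer goes back to the reader, which clears and refills it. The hypothetical Handoff struct below isolates that swap-and-check logic, with std::vector<int> standing in for std::vector<NnetExample>; the semaphore calls are omitted, and the assert is this sketch's own guard against calling again after the end of data.

```cpp
#include <cassert>
#include <vector>

// Hypothetical reduction of GetNextMinibatch(): `shared` plays the role of
// examples_ and `finished` the role of finished_. Synchronization is omitted.
struct Handoff {
  std::vector<int> shared;
  bool finished = false;

  // Returns true if a non-empty batch was handed over, false at end of data.
  bool TakeBatch(std::vector<int> *out) {
    assert(!finished);  // calling again after end of data is a usage error
    shared.swap(*out);  // O(1): exchanges buffers, copies no elements
    if (out->empty()) {
      finished = true;
      return false;
    }
    return true;
  }
};
```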
void ReadExamples ( )
inline

Definition at line 49 of file train-nnet.cc.

References NnetExampleBackgroundReader::consumer_semaphore_, SequentialTableReader< Holder >::Done(), NnetExampleBackgroundReader::examples_, kaldi::nnet2::FormatNnetInput(), NnetExampleBackgroundReader::formatted_examples_, KALDI_ASSERT, NnetExampleBackgroundReader::minibatch_size_, SequentialTableReader< Holder >::Next(), NnetExampleBackgroundReader::nnet_, NnetExampleBackgroundReader::producer_semaphore_, NnetExampleBackgroundReader::reader_, Matrix< Real >::Resize(), Semaphore::Signal(), NnetExampleBackgroundReader::total_weight_, kaldi::nnet2::TotalNnetTrainingWeight(), SequentialTableReader< Holder >::Value(), and Semaphore::Wait().

Referenced by NnetExampleBackgroundReader::Run().

{
  KALDI_ASSERT(minibatch_size_ > 0);
  int32 minibatch_size = minibatch_size_;

  // Loop over minibatches...
  while (true) {
    // When the following call succeeds we interpret it as a signal that
    // we are free to write to the class-member variables examples_ and formatted_examples_.
    consumer_semaphore_.Wait();

    examples_.clear();
    examples_.reserve(minibatch_size);
    // Read the examples.
    for (; examples_.size() < minibatch_size && !reader_->Done(); reader_->Next())
      examples_.push_back(reader_->Value());

    // Format the examples as a single matrix. The reason we do this here is
    // that it's a somewhat CPU-intensive operation (involves decompressing
    // the matrix), so we do it in a separate thread from the one that's
    // controlling the GPU (assuming we're using a GPU), so we can get better
    // GPU utilization. If we have no GPU this doesn't hurt us.
    if (examples_.empty()) {
      formatted_examples_.Resize(0, 0);
      total_weight_ = 0.0;
    } else {
      FormatNnetInput(*nnet_, examples_, &formatted_examples_);
      total_weight_ = TotalNnetTrainingWeight(examples_);
    }

    bool finished = examples_.empty();

    // The following call alerts the main program thread (that calls
    // GetNextMinibatch()) that it can now use the contents of
    // examples_ and formatted_examples_.
    producer_semaphore_.Signal();

    // If we just read an empty minibatch (because no more examples),
    // then return.
    if (finished)
      return;
  }
}
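ReadExamples() and GetNextMinibatch() together form a single-buffer producer/consumer handshake: consumer_semaphore_ counts "buffer free" events and producer_semaphore_ counts "buffer filled" events. The reader therefore never overwrites examples_ while the training loop is using them, yet it can format minibatch k+1 while the caller works on minibatch k. The sketch below reproduces that protocol under stated assumptions: a std::vector<int> replaces the SequentialNnetExampleReader, summing a batch replaces FormatNnetInput() and TotalNnetTrainingWeight(), and Semaphore is a hypothetical stand-in for Kaldi's class. All names are illustrative.

```cpp
#include <cassert>
#include <condition_variable>
#include <cstddef>
#include <mutex>
#include <thread>
#include <vector>

// Hypothetical stand-in for kaldi::Semaphore (not Kaldi's implementation).
class Semaphore {
 public:
  void Signal() {  // increase the counter
    std::lock_guard<std::mutex> lock(mutex_);
    ++count_;
    cond_.notify_one();
  }
  void Wait() {  // block until positive, then decrease the counter
    std::unique_lock<std::mutex> lock(mutex_);
    cond_.wait(lock, [this] { return count_ > 0; });
    --count_;
  }

 private:
  std::mutex mutex_;
  std::condition_variable cond_;
  int count_ = 0;
};

// Hypothetical analogue of NnetExampleBackgroundReader that batches ints.
class BackgroundBatcher {
 public:
  BackgroundBatcher(std::size_t minibatch_size, const std::vector<int> *source)
      : minibatch_size_(minibatch_size), source_(source) {
    // All members are constructed before the body runs, so the new thread
    // can safely use them.
    thread_ = std::thread(&BackgroundBatcher::ReadBatches, this);
    consumer_semaphore_.Signal();  // batch_ starts out free for the producer
  }

  ~BackgroundBatcher() {
    if (thread_.joinable()) thread_.join();
  }

  // Consumer side, like GetNextMinibatch(): false means no more data.
  bool GetNextBatch(std::vector<int> *batch, long *sum) {
    producer_semaphore_.Wait();    // wait until batch_ has been filled
    batch_.swap(*batch);           // O(1) handoff, no element copies
    *sum = sum_;
    consumer_semaphore_.Signal();  // the producer may now refill batch_
    return !batch->empty();
  }

 private:
  // Producer side, like ReadExamples(); runs on thread_.
  void ReadBatches() {
    while (true) {
      consumer_semaphore_.Wait();  // wait until batch_ is free
      batch_.clear();
      batch_.reserve(minibatch_size_);
      for (; batch_.size() < minibatch_size_ && next_ < source_->size(); ++next_)
        batch_.push_back((*source_)[next_]);
      sum_ = 0;  // stands in for the CPU-heavy formatting step
      for (int x : batch_) sum_ += x;
      bool finished = batch_.empty();  // read batch_ before handing it over
      producer_semaphore_.Signal();    // hand the filled batch to the consumer
      if (finished) return;            // an empty batch marks end of data
    }
  }

  std::size_t minibatch_size_;
  const std::vector<int> *source_;
  std::size_t next_ = 0;  // index of the next unread element
  std::vector<int> batch_;
  long sum_ = 0;
  Semaphore producer_semaphore_;
  Semaphore consumer_semaphore_;
  std::thread thread_;
};
```

As in the listing, `finished` is computed before the producer signals, because once Signal() returns the consumer may already have swapped batch_ away. Also note that the producer exits only after delivering an empty batch, so in this sketch the destructor's join() returns only if the consumer keeps calling GetNextBatch() until it returns false; stopping early would leave the producer blocked in consumer_semaphore_.Wait() and the join would hang.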
static void* Run ( void *  ptr_in)
inline static

Definition at line 94 of file train-nnet.cc.

References NnetExampleBackgroundReader::ReadExamples().

Referenced by NnetExampleBackgroundReader::NnetExampleBackgroundReader().

{
  NnetExampleBackgroundReader *ptr =
      reinterpret_cast<NnetExampleBackgroundReader*>(ptr_in);
  ptr->ReadExamples();
  return NULL;
}
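Run() is a trampoline: a static function with the void *(void *) shape that pthread_create() expects for a start routine, which casts its argument back to the object and calls ReadExamples(). std::thread does not require that shape, because it can invoke a pointer-to-member function with the object pointer as its first argument, so the signature is possibly a holdover from an earlier pthreads-based version. The hypothetical Worker class and RunBothWays() function below start the same member function both ways.

```cpp
#include <cassert>
#include <thread>

// Hypothetical class with a trampoline shaped like
// NnetExampleBackgroundReader::Run().
class Worker {
 public:
  // void *(void *) start routine: recover the object from the untyped
  // pointer and forward to the real member function. (static_cast is the
  // more conventional cast from void*; reinterpret_cast mirrors the listing.)
  static void *Run(void *ptr_in) {
    Worker *ptr = reinterpret_cast<Worker *>(ptr_in);
    ptr->DoWork();
    return nullptr;
  }

  void DoWork() { ++calls_; }
  int calls() const { return calls_; }

 private:
  int calls_ = 0;
};

// Starts DoWork() once through the trampoline and once directly.
int RunBothWays() {
  Worker w;
  std::thread t1(Worker::Run, &w);      // Worker* converts to void* implicitly
  t1.join();
  std::thread t2(&Worker::DoWork, &w);  // pointer-to-member, no trampoline
  t2.join();
  return w.calls();  // safe to read: both threads have been joined
}
```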

Member Data Documentation

std::vector<NnetExample> examples_
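private

Definition at line 134 of file train-nnet.cc.

Referenced by NnetExampleBackgroundReader::GetNextMinibatch(), and NnetExampleBackgroundReader::ReadExamples().
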
bool finished_
private

Definition at line 142 of file train-nnet.cc.

Referenced by NnetExampleBackgroundReader::GetNextMinibatch().

Matrix<BaseFloat> formatted_examples_
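private

Referenced by NnetExampleBackgroundReader::GetNextMinibatch(), and NnetExampleBackgroundReader::ReadExamples().
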
int32 minibatch_size_
private

Definition at line 129 of file train-nnet.cc.

Referenced by NnetExampleBackgroundReader::ReadExamples().

Nnet* nnet_
private

Definition at line 130 of file train-nnet.cc.

Referenced by NnetExampleBackgroundReader::ReadExamples().

Semaphore producer_semaphore_
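private

Referenced by NnetExampleBackgroundReader::GetNextMinibatch(), and NnetExampleBackgroundReader::ReadExamples().
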
SequentialNnetExampleReader* reader_
private

Definition at line 131 of file train-nnet.cc.

Referenced by NnetExampleBackgroundReader::ReadExamples().

double total_weight_
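private

Referenced by NnetExampleBackgroundReader::GetNextMinibatch(), and NnetExampleBackgroundReader::ReadExamples().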

The documentation for this class was generated from the following file: