nnet-example.cc
Go to the documentation of this file.
1 // nnet2/nnet-example.cc
2 
3 // Copyright 2012-2013 Johns Hopkins University (author: Daniel Povey)
4 // 2014 Vimal Manohar
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include "nnet2/nnet-example.h"
22 #include "lat/lattice-functions.h"
23 #include "hmm/posterior.h"
24 
25 namespace kaldi {
26 namespace nnet2 {
27 
28 // This function returns true if the example has labels which, for each frame,
29 // have a single element with probability one; and if so, it outputs them to the
30 // vector in the associated pointer. This enables us to write the egs more
31 // compactly to disk in this common case.
33  const NnetExample &eg,
34  std::vector<int32> *simple_labels) {
35  size_t num_frames = eg.labels.size();
36  for (int32 t = 0; t < num_frames; t++)
37  if (eg.labels[t].size() != 1 || eg.labels[t][0].second != 1.0)
38  return false;
39  simple_labels->resize(num_frames);
40  for (int32 t = 0; t < num_frames; t++)
41  (*simple_labels)[t] = eg.labels[t][0].first;
42  return true;
43 }
44 
45 
46 void NnetExample::Write(std::ostream &os, bool binary) const {
47  // Note: weight, label, input_frames and spk_info are members. This is a
48  // struct.
49  WriteToken(os, binary, "<NnetExample>");
50 
51  // At this point, we write <Lab1> if we have "simple" labels, or
52  // <Lab2> in general. Previous code (when we had only one frame of
53  // labels) just wrote <Labels>.
54  std::vector<int32> simple_labels;
55  if (HasSimpleLabels(*this, &simple_labels)) {
56  WriteToken(os, binary, "<Lab1>");
57  WriteIntegerVector(os, binary, simple_labels);
58  } else {
59  WriteToken(os, binary, "<Lab2>");
60  int32 num_frames = labels.size();
61  WriteBasicType(os, binary, num_frames);
62  for (int32 t = 0; t < num_frames; t++) {
63  int32 size = labels[t].size();
64  WriteBasicType(os, binary, size);
65  for (int32 i = 0; i < size; i++) {
66  WriteBasicType(os, binary, labels[t][i].first);
67  WriteBasicType(os, binary, labels[t][i].second);
68  }
69  }
70  }
71  WriteToken(os, binary, "<InputFrames>");
72  input_frames.Write(os, binary);
73  WriteToken(os, binary, "<LeftContext>");
74  WriteBasicType(os, binary, left_context);
75  WriteToken(os, binary, "<SpkInfo>");
76  spk_info.Write(os, binary);
77  WriteToken(os, binary, "</NnetExample>");
78 }
79 
80 void NnetExample::Read(std::istream &is, bool binary) {
81  // Note: weight, label, input_frames, left_context and spk_info are members.
82  // This is a struct.
83  ExpectToken(is, binary, "<NnetExample>");
84 
85  std::string token;
86  ReadToken(is, binary, &token);
87  if (!strcmp(token.c_str(), "<Lab1>")) { // simple label format
88  std::vector<int32> simple_labels;
89  ReadIntegerVector(is, binary, &simple_labels);
90  labels.resize(simple_labels.size());
91  for (size_t i = 0; i < simple_labels.size(); i++) {
92  labels[i].resize(1);
93  labels[i][0].first = simple_labels[i];
94  labels[i][0].second = 1.0;
95  }
96  } else if (!strcmp(token.c_str(), "<Lab2>")) { // generic label format
97  int32 num_frames;
98  ReadBasicType(is, binary, &num_frames);
99  KALDI_ASSERT(num_frames > 0);
100  labels.resize(num_frames);
101  for (int32 t = 0; t < num_frames; t++) {
102  int32 size;
103  ReadBasicType(is, binary, &size);
104  KALDI_ASSERT(size >= 0);
105  labels[t].resize(size);
106  for (int32 i = 0; i < size; i++) {
107  ReadBasicType(is, binary, &(labels[t][i].first));
108  ReadBasicType(is, binary, &(labels[t][i].second));
109  }
110  }
111  } else if (!strcmp(token.c_str(), "<Labels>")) { // back-compatibility
112  labels.resize(1); // old format had 1 frame of labels.
113  int32 size;
114  ReadBasicType(is, binary, &size);
115  labels[0].resize(size);
116  for (int32 i = 0; i < size; i++) {
117  ReadBasicType(is, binary, &(labels[0][i].first));
118  ReadBasicType(is, binary, &(labels[0][i].second));
119  }
120  } else {
121  KALDI_ERR << "Expected token <Lab1>, <Lab2> or <Labels>, got " << token;
122  }
123  ExpectToken(is, binary, "<InputFrames>");
124  input_frames.Read(is, binary);
125  ExpectToken(is, binary, "<LeftContext>"); // Note: this member is
126  // recently added, but I don't think we'll get too much back-compatibility
127  // problems from not handling the old format.
128  ReadBasicType(is, binary, &left_context);
129  ExpectToken(is, binary, "<SpkInfo>");
130  spk_info.Read(is, binary);
131  ExpectToken(is, binary, "</NnetExample>");
132 }
133 
134 void NnetExample::SetLabelSingle(int32 frame, int32 pdf_id, BaseFloat weight) {
135  KALDI_ASSERT(static_cast<size_t>(frame) < labels.size());
136  labels[frame].clear();
137  labels[frame].push_back(std::make_pair(pdf_id, weight));
138 }
139 
141  BaseFloat max = -1.0;
142  int32 pdf_id = -1;
143  KALDI_ASSERT(static_cast<size_t>(frame) < labels.size());
144  for (int32 i = 0; i < labels[frame].size(); i++) {
145  if (labels[frame][i].second > max) {
146  pdf_id = labels[frame][i].first;
147  max = labels[frame][i].second;
148  }
149  }
150  if (weight != NULL) *weight = max;
151  return pdf_id;
152 }
153 
154 
155 
157 
158 // Self-constructor that can reduce the number of frames and/or context.
160  int32 start_frame,
161  int32 new_num_frames,
162  int32 new_left_context,
163  int32 new_right_context): spk_info(input.spk_info) {
164  int32 num_label_frames = input.labels.size();
165  if (start_frame < 0) start_frame = 0; // start_frame is offset in the labeled
166  // frames.
167  KALDI_ASSERT(start_frame < num_label_frames);
168  if (start_frame + new_num_frames > num_label_frames || new_num_frames == -1)
169  new_num_frames = num_label_frames - start_frame;
170  // compute right-context of input.
171  int32 input_right_context =
172  input.input_frames.NumRows() - input.left_context - num_label_frames;
173  if (new_left_context == -1) new_left_context = input.left_context;
174  if (new_right_context == -1) new_right_context = input_right_context;
175  if (new_left_context > input.left_context) {
176  if (!nnet_example_warned_left) {
177  nnet_example_warned_left = true;
178  KALDI_WARN << "Requested left-context " << new_left_context
179  << " exceeds input left-context " << input.left_context
180  << ", will not warn again.";
181  }
182  new_left_context = input.left_context;
183  }
184  if (new_right_context > input_right_context) {
187  KALDI_WARN << "Requested right-context " << new_right_context
188  << " exceeds input right-context " << input_right_context
189  << ", will not warn again.";
190  }
191  new_right_context = input_right_context;
192  }
193 
194  int32 new_tot_frames = new_left_context + new_num_frames + new_right_context,
195  left_frames_lost = (input.left_context - new_left_context) + start_frame;
196 
197  CompressedMatrix new_input_frames(input.input_frames,
198  left_frames_lost,
199  new_tot_frames,
200  0, input.input_frames.NumCols());
201  new_input_frames.Swap(&input_frames); // swap with class-member.
202  left_context = new_left_context; // set class-member.
203  labels.clear();
204  labels.insert(labels.end(),
205  input.labels.begin() + start_frame,
206  input.labels.begin() + start_frame + new_num_frames);
207 }
208 
210  std::vector<NnetExample> *examples) {
211  KALDI_ASSERT(!examples->empty());
212  empty_semaphore_.Wait();
213  KALDI_ASSERT(examples_.empty());
214  examples_.swap(*examples);
215  full_semaphore_.Signal();
216 }
217 
219  empty_semaphore_.Wait();
220  KALDI_ASSERT(examples_.empty());
221  done_ = true;
222  full_semaphore_.Signal();
223 }
224 
226  std::vector<NnetExample> *examples) {
227  full_semaphore_.Wait();
228  if (done_) {
229  KALDI_ASSERT(examples_.empty());
230  full_semaphore_.Signal(); // Increment the semaphore so
231  // the call by the next thread will not block.
232  return false; // no examples to return-- all finished.
233  } else {
234  KALDI_ASSERT(!examples_.empty() && examples->empty());
235  examples->swap(examples_);
236  empty_semaphore_.Signal();
237  return true;
238  }
239 }
240 
241 
242 void DiscriminativeNnetExample::Write(std::ostream &os,
243  bool binary) const {
244  // Note: weight, num_ali, den_lat, input_frames, left_context and spk_info are
245  // members. This is a struct.
246  WriteToken(os, binary, "<DiscriminativeNnetExample>");
247  WriteToken(os, binary, "<Weight>");
248  WriteBasicType(os, binary, weight);
249  WriteToken(os, binary, "<NumAli>");
250  WriteIntegerVector(os, binary, num_ali);
251  if (!WriteCompactLattice(os, binary, den_lat)) {
252  // We can't return error status from this function so we
253  // throw an exception.
254  KALDI_ERR << "Error writing CompactLattice to stream";
255  }
256  WriteToken(os, binary, "<InputFrames>");
257  {
258  CompressedMatrix cm(input_frames); // Note: this can be read as a regular
259  // matrix.
260  cm.Write(os, binary);
261  }
262  WriteToken(os, binary, "<LeftContext>");
263  WriteBasicType(os, binary, left_context);
264  WriteToken(os, binary, "<SpkInfo>");
265  spk_info.Write(os, binary);
266  WriteToken(os, binary, "</DiscriminativeNnetExample>");
267 }
268 
269 void DiscriminativeNnetExample::Read(std::istream &is,
270  bool binary) {
271  // Note: weight, num_ali, den_lat, input_frames, left_context and spk_info are
272  // members. This is a struct.
273  ExpectToken(is, binary, "<DiscriminativeNnetExample>");
274  ExpectToken(is, binary, "<Weight>");
275  ReadBasicType(is, binary, &weight);
276  ExpectToken(is, binary, "<NumAli>");
277  ReadIntegerVector(is, binary, &num_ali);
278  CompactLattice *den_lat_tmp = NULL;
279  if (!ReadCompactLattice(is, binary, &den_lat_tmp) || den_lat_tmp == NULL) {
280  // We can't return error status from this function so we
281  // throw an exception.
282  KALDI_ERR << "Error reading CompactLattice from stream";
283  }
284  den_lat = *den_lat_tmp;
285  delete den_lat_tmp;
286  ExpectToken(is, binary, "<InputFrames>");
287  input_frames.Read(is, binary);
288  ExpectToken(is, binary, "<LeftContext>");
289  ReadBasicType(is, binary, &left_context);
290  ExpectToken(is, binary, "<SpkInfo>");
291  spk_info.Read(is, binary);
292  ExpectToken(is, binary, "</DiscriminativeNnetExample>");
293 }
294 
296  KALDI_ASSERT(weight > 0.0);
297  KALDI_ASSERT(!num_ali.empty());
298  int32 num_frames = static_cast<int32>(num_ali.size());
299 
300 
301  std::vector<int32> times;
302  int32 num_frames_den = CompactLatticeStateTimes(den_lat, &times);
303  KALDI_ASSERT(num_frames == num_frames_den);
304  KALDI_ASSERT(input_frames.NumRows() >= left_context + num_frames);
305 }
306 
307 
308 } // namespace nnet2
309 } // namespace kaldi
CompressedMatrix input_frames
The input data, with NumRows() >= labels.size() + left_context; it includes features to the left and ...
Definition: nnet-example.h:49
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
NnetExample is the input data and corresponding label (or labels) for one or more frames of input...
Definition: nnet-example.h:36
void AcceptExamples(std::vector< NnetExample > *examples)
The following function is called by the code that reads in the examples, with a batch of examples...
void Swap(CompressedMatrix *other)
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
int32 left_context
The number of frames of left context (we can work out the #frames of right context from input_frames...
Definition: nnet-example.h:53
void ExamplesDone()
The following function is called by the code that reads in the examples, when we're done reading exam...
void Read(std::istream &is, bool binary)
Definition: nnet-example.cc:80
kaldi::int32 int32
void ReadToken(std::istream &is, bool binary, std::string *str)
ReadToken gets the next token and puts it in str (exception on failure).
Definition: io-funcs.cc:154
void Write(std::ostream &os, bool binary) const
void Read(std::istream &is, bool binary)
void Write(std::ostream &os, bool binary) const
void ReadIntegerVector(std::istream &is, bool binary, std::vector< T > *v)
Function for reading STL vector of integer types.
Definition: io-funcs-inl.h:232
void ExpectToken(std::istream &is, bool binary, const char *token)
ExpectToken tries to read in the given token, and throws an exception on failure. ...
Definition: io-funcs.cc:191
static bool nnet_example_warned_right
bool ProvideExamples(std::vector< NnetExample > *examples)
This function is called by the code that does the training.
#define KALDI_ERR
Definition: kaldi-error.h:147
bool HasSimpleLabels(const NnetExample &eg, std::vector< int32 > *simple_labels)
Definition: nnet-example.cc:32
int32 CompactLatticeStateTimes(const CompactLattice &lat, vector< int32 > *times)
As LatticeStateTimes, but in the CompactLattice format.
#define KALDI_WARN
Definition: kaldi-error.h:150
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
fst::VectorFst< CompactLatticeArc > CompactLattice
Definition: kaldi-lattice.h:46
int32 GetLabelSingle(int32 frame, BaseFloat *weight=NULL)
Get the maximum weight label (pdf_id and weight) of this frame of this example.
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
bool WriteCompactLattice(std::ostream &os, bool binary, const CompactLattice &t)
void WriteIntegerVector(std::ostream &os, bool binary, const std::vector< T > &v)
Function for writing STL vectors of integer types.
Definition: io-funcs-inl.h:198
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34
static bool nnet_example_warned_left
std::vector< std::vector< std::pair< int32, BaseFloat > > > labels
The label(s) for each frame in a sequence of frames; in the normal case, this will be just [ [ (pdf-i...
Definition: nnet-example.h:43
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
void SetLabelSingle(int32 frame, int32 pdf_id, BaseFloat weight=1.0)
Set the label of this frame of this example to the specified pdf_id with the specified weight...
bool ReadCompactLattice(std::istream &is, bool binary, CompactLattice **clat)
void Read(std::istream &is, bool binary)
void Write(std::ostream &os, bool binary) const
Definition: nnet-example.cc:46
Vector< BaseFloat > spk_info
The speaker-specific input, if any, or an empty vector if we're not using this feature.
Definition: nnet-example.h:58