nnet-common.cc
Go to the documentation of this file.
1 // nnet3/nnet-common.cc
2 
3 // Copyright 2015 Johns Hopkins University (author: Daniel Povey)
4 // 2016 Xiaohui Zhang
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include "nnet3/nnet-common.h"
22 
23 namespace kaldi {
24 namespace nnet3 {
25 
26 // Don't write with too many markers as we don't want to take up too much space.
27 void Index::Write(std::ostream &os, bool binary) const {
28  // writing this token will make it easier to write back-compatible code later
29  // on.
30  WriteToken(os, binary, "<I1>");
31  WriteBasicType(os, binary, n);
32  WriteBasicType(os, binary, t);
33  WriteBasicType(os, binary, x);
34 }
35 
36 
37 void Index::Read(std::istream &is, bool binary) {
38  ExpectToken(is, binary, "<I1>");
39  ReadBasicType(is, binary, &n);
40  ReadBasicType(is, binary, &t);
41  ReadBasicType(is, binary, &x);
42 }
43 
44 
46  std::ostream &os,
47  const std::vector<Index> &vec,
48  int32 i) {
49  bool binary = true;
50  const Index &index = vec[i];
51  if (i == 0) {
52  // we don't use std::abs(index.t) < 125 here because it doesn't have the
53  // right (or even well-defined) behavior for
54  // index.t == std::numeric_limits<int32>::min().
55  if (index.n == 0 && index.x == 0 &&
56  index.t > -125 && index.t < 125) {
57  // handle this common case in one character.
58  os.put(static_cast<signed char>(index.t));
59  } else { // handle the general case less efficiently.
60  os.put(127);
61  WriteBasicType(os, binary, index.n);
62  WriteBasicType(os, binary, index.t);
63  WriteBasicType(os, binary, index.x);
64  }
65  } else {
66  Index last_index = vec[i-1];
67  // we don't do if (std::abs(index.t - last_index.t) < 125)
68  // below because this doesn't work right if that difference
69  // equals std::numeric_limits<int32>::min().
70  if (index.n == last_index.n && index.x == last_index.x &&
71  index.t - last_index.t < 125 &&
72  index.t - last_index.t > -125) {
73  signed char c = index.t - last_index.t;
74  os.put(c);
75  } else { // handle the general case less efficiently.
76  os.put(127);
77  WriteBasicType(os, binary, index.n);
78  WriteBasicType(os, binary, index.t);
79  WriteBasicType(os, binary, index.x);
80  }
81  }
82  if (!os.good())
83  KALDI_ERR << "Output stream error detected";
84 }
85 
86 
88  std::istream &is,
89  int32 i,
90  std::vector<Index> *vec) {
91  bool binary = true;
92  Index &index = (*vec)[i];
93  if (!is.good())
94  KALDI_ERR << "End of file while reading vector of Index.";
95  signed char c = is.get();
96  if (i == 0) {
97  if (std::abs(int(c)) < 125) {
98  index.n = 0;
99  index.t = c;
100  index.x = 0;
101  } else {
102  if (c != 127)
103  KALDI_ERR << "Unexpected character " << c
104  << " encountered while reading Index vector.";
105  ReadBasicType(is, binary, &(index.n));
106  ReadBasicType(is, binary, &(index.t));
107  ReadBasicType(is, binary, &(index.x));
108  }
109  } else {
110  Index &last_index = (*vec)[i-1];
111  if (std::abs(int(c)) < 125) {
112  index.n = last_index.n;
113  index.t = last_index.t + c;
114  index.x = last_index.x;
115  } else {
116  if (c != 127)
117  KALDI_ERR << "Unexpected character " << c
118  << " encountered while reading Index vector.";
119  ReadBasicType(is, binary, &(index.n));
120  ReadBasicType(is, binary, &(index.t));
121  ReadBasicType(is, binary, &(index.x));
122  }
123  }
124 }
125 
126 void WriteIndexVector(std::ostream &os, bool binary,
127  const std::vector<Index> &vec) {
128  // This token will make it easier to write back-compatible code if we later
129  // change the format.
130  WriteToken(os, binary, "<I1V>");
131  int32 size = vec.size();
132  WriteBasicType(os, binary, size);
133  if (!binary) { // In text mode we just use the native Write functionality.
134  for (int32 i = 0; i < size; i++)
135  vec[i].Write(os, binary);
136  } else {
137  for (int32 i = 0; i < size; i++)
139  }
140 }
141 
142 
143 void ReadIndexVector(std::istream &is, bool binary,
144  std::vector<Index> *vec) {
145  ExpectToken(is, binary, "<I1V>");
146  int32 size;
147  ReadBasicType(is, binary, &size);
148  if (size < 0) {
149  KALDI_ERR << "Error reading Index vector: size = "
150  << size;
151  }
152  vec->resize(size);
153  if (!binary) {
154  for (int32 i = 0; i < size; i++)
155  (*vec)[i].Read(is, binary);
156  } else {
157  for (int32 i = 0; i < size; i++)
159  }
160 }
161 
163  std::ostream &os,
164  const std::vector<Cindex> &vec,
165  int32 i) {
166  bool binary = true;
167  int32 node_index = vec[i].first;
168  const Index &index = vec[i].second;
169  if (i == 0 || node_index != vec[i-1].first) {
170  // divide using '|' into ranges that each have all the same node name, like:
171  // [node_1: index_1 index_2] [node_2: index_3 index_4] Caution: '|' is
172  // character 124 so we have to avoid that character in places where it might
173  // be confused with this separator.
174  os.put('|');
175  WriteBasicType(os, binary, node_index);
176  }
177  if (i == 0) {
178  // we don't need to be concerned about reserving space for character 124
179  // ('|') here, since (wastefully) '|' is always printed for i == 0.
180  //
181  // we don't use std::abs(index.t) < 125 here because it doesn't have the
182  // right (or even well-defined) behavior for
183  // index.t == std::numeric_limits<int32>::min().
184  if (index.n == 0 && index.x == 0 &&
185  index.t > -125 && index.t < 125) {
186  // handle this common case in one character.
187  os.put(static_cast<signed char>(index.t));
188  } else if (index.t == 0 && index.x == 0 &&
189  (index.n == 0 || index.n == 1)) {
190  // handle this common case in one character.
191  os.put(static_cast<signed char>(index.n + 125));
192  } else { // handle the general case less efficiently.
193  os.put(127);
194  WriteBasicType(os, binary, index.n);
195  WriteBasicType(os, binary, index.t);
196  WriteBasicType(os, binary, index.x);
197  }
198  } else {
199  const Index &last_index = vec[i-1].second;
200  // we don't do if std::abs(index.t - last_index.t) < 124
201  // below because it doesn't work right if the difference
202  // equals std::numeric_limits<int32>::min().
203  if (index.n == last_index.n && index.x == last_index.x &&
204  index.t - last_index.t < 124 &&
205  index.t - last_index.t > -124) {
206  signed char c = index.t - last_index.t;
207  os.put(c);
208  // note: we have to reserve character 124 ('|') for when 'n' or 'x'
209  // changes.
210  } else if (index.t == last_index.t && index.x == last_index.x &&
211  (index.n == last_index.n || index.n == last_index.n + 1)) {
212  os.put(125 + index.n - last_index.n);
213  } else { // handle the general case less efficiently.
214  os.put(127);
215  WriteBasicType(os, binary, index.n);
216  WriteBasicType(os, binary, index.t);
217  WriteBasicType(os, binary, index.x);
218  }
219  }
220  if (!os.good())
221  KALDI_ERR << "Output stream error detected";
222 }
223 
225  std::istream &is,
226  int32 i,
227  std::vector<Cindex> *vec) {
228  bool binary = true;
229  Index &index = (*vec)[i].second;
230  if (!is.good())
231  KALDI_ERR << "End of file while reading vector of Cindex.";
232  if (is.peek() == static_cast<int>('|')) {
233  is.get();
234  ReadBasicType(is, binary, &((*vec)[i].first));
235  } else {
236  KALDI_ASSERT(i != 0);
237  (*vec)[i].first = (*vec)[i-1].first;
238  }
239  signed char c = is.get();
240  if (i == 0) {
241  if (std::abs(int(c)) < 125) {
242  index.n = 0;
243  index.t = c;
244  index.x = 0;
245  } else if (c == 125 || c == 126) {
246  index.n = c - 125;
247  index.t = 0;
248  index.x = 0;
249  } else {
250  if (c != 127)
251  KALDI_ERR << "Unexpected character " << c
252  << " encountered while reading Cindex vector.";
253  ReadBasicType(is, binary, &(index.n));
254  ReadBasicType(is, binary, &(index.t));
255  ReadBasicType(is, binary, &(index.x));
256  }
257  } else {
258  Index &last_index = (*vec)[i-1].second;
259  if (std::abs(int(c)) < 124) {
260  index.n = last_index.n;
261  index.t = last_index.t + c;
262  index.x = last_index.x;
263  } else if (c == 125 || c == 126) {
264  index.n = last_index.n + c - 125;
265  index.t = last_index.t;
266  index.x = last_index.x;
267  } else {
268  if (c != 127)
269  KALDI_ERR << "Unexpected character " << c
270  << " encountered while reading Cindex vector.";
271  ReadBasicType(is, binary, &(index.n));
272  ReadBasicType(is, binary, &(index.t));
273  ReadBasicType(is, binary, &(index.x));
274  }
275  }
276 }
277 
278 // This function writes elements of a Cindex vector in a compact form.
279 // which is similar as the output of PrintCindexes. The vector is divided
280 // into ranges that each have all the same node name, like:
281 // [node_1: index_1 index_2] [node_2: index_3 index_4]
282 void WriteCindexVector(std::ostream &os, bool binary,
283  const std::vector<Cindex> &vec) {
284  // This token will make it easier to write back-compatible code if we later
285  // change the format.
286  WriteToken(os, binary, "<I1V>");
287  int32 size = vec.size();
288  WriteBasicType(os, binary, size);
289  if (!binary) { // In text mode we just use the native Write functionality.
290  for (int32 i = 0; i < size; i++) {
291  int32 node_index = vec[i].first;
292  if (i == 0 || node_index != vec[i-1].first) {
293  if (i > 0)
294  os.put(']');
295  os.put('[');
296  WriteBasicType(os, binary, node_index);
297  os.put(':');
298  }
299  vec[i].second.Write(os, binary);
300  if (i == size - 1)
301  os.put(']');
302  }
303  } else {
304  for (int32 i = 0; i < size; i++)
306  }
307 }
308 
309 void ReadCindexVector(std::istream &is, bool binary,
310  std::vector<Cindex> *vec) {
311  ExpectToken(is, binary, "<I1V>");
312  int32 size;
313  ReadBasicType(is, binary, &size);
314  if (size < 0) {
315  KALDI_ERR << "Error reading Index vector: size = "
316  << size;
317  }
318  vec->resize(size);
319  if (!binary) {
320  for (int32 i = 0; i < size; i++) {
321  is >> std::ws;
322  if (is.peek() == static_cast<int>(']') || i == 0) {
323  if (i != 0)
324  is.get();
325  is >> std::ws;
326  if (is.peek() == static_cast<int>('[')) {
327  is.get();
328  } else {
329  KALDI_ERR << "ReadCintegerVector: expected to see [, saw "
330  << is.peek() << ", at file position " << is.tellg();
331  }
332  ReadBasicType(is, binary, &((*vec)[i].first));
333  is >> std::ws;
334  if (is.peek() == static_cast<int>(':')) {
335  is.get();
336  } else {
337  KALDI_ERR << "ReadCintegerVector: expected to see :, saw "
338  << is.peek() << ", at file position " << is.tellg();
339  }
340  } else {
341  (*vec)[i].first = (*vec)[i-1].first;
342  }
343  (*vec)[i].second.Read(is, binary);
344  if (i == size - 1) {
345  is >> std::ws;
346  if (is.peek() == static_cast<int>(']')) {
347  is.get();
348  } else {
349  KALDI_ERR << "ReadCintegerVector: expected to see ], saw "
350  << is.peek() << ", at file position " << is.tellg();
351  }
352  }
353  }
354  } else {
355  for (int32 i = 0; i < size; i++)
357  }
358 }
359 
360 size_t IndexHasher::operator () (const Index &index) const noexcept {
361  // The numbers that appear below were chosen arbitrarily from a list of primes
362  return index.n +
363  1619 * index.t +
364  15649 * index.x;
365 }
366 
367 size_t CindexHasher::operator () (const Cindex &cindex) const noexcept {
368  // The numbers that appear below were chosen arbitrarily from a list of primes
369  return cindex.first +
370  1619 * cindex.second.n +
371  15649 * cindex.second.t +
372  89809 * cindex.second.x;
373 
374 }
375 
377  const std::vector<Cindex> &cindex_vector) const noexcept {
378  // this is an arbitrarily chosen prime.
379  size_t kPrime = 23539, ans = 0;
380  std::vector<Cindex>::const_iterator iter = cindex_vector.begin(),
381  end = cindex_vector.end();
382  CindexHasher cindex_hasher;
383  for (; iter != end; ++iter)
384  ans = cindex_hasher(*iter) + kPrime * ans;
385  return ans;
386 }
387 
389  const std::vector<Index> &index_vector) const noexcept {
390  size_t n1 = 15, n2 = 10; // n1 and n2 are used to extract only a subset of
391  // elements to hash; this makes the hasher faster by
392  // skipping over more elements. Setting n1 large or
393  // n2 to 1 would make the hasher consider all
394  // elements.
395  size_t len = index_vector.size();
396  // all long-ish numbers appearing below are randomly chosen primes.
397  size_t ans = 1433 + 34949 * len;
398  std::vector<Index>::const_iterator iter = index_vector.begin(),
399  end = index_vector.end(), med = end;
400  if (n1 < len)
401  med = iter + n1;
402 
403  for (; iter != med; ++iter) {
404  ans += iter->n * 1619;
405  ans += iter->t * 15649;
406  ans += iter->x * 89809;
407  }
408  // after the first n1 values, look only at every n2'th value. this makes the
409  // hashing much faster, and in the kinds of structures that we actually deal
410  // with, we shouldn't get unnecessary hash collisions as a result of this
411  // optimization.
412  for (; iter < end; iter += n2) {
413  ans += iter->n * 1619;
414  ans += iter->t * 15649;
415  ans += iter->x * 89809;
416  // The following if-statement was introduced in order to fix an
417  // out-of-range iterator problem on Windows.
418  if (n2 > len || iter >= end - n2)
419  break;
420  }
421  return ans;
422 }
423 
424 std::ostream &operator << (std::ostream &ostream, const Index &index) {
425  return ostream << '(' << index.n << ' ' << index.t << ' ' << index.x << ')';
426 }
427 
428 std::ostream &operator << (std::ostream &ostream, const Cindex &cindex) {
429  return ostream << '(' << cindex.first << ' ' << cindex.second << ')';
430 }
431 
432 void PrintCindex(std::ostream &os, const Cindex &cindex,
433  const std::vector<std::string> &node_names) {
434  KALDI_ASSERT(static_cast<size_t>(cindex.first) < node_names.size());
435  os << node_names[cindex.first] << "(" << cindex.second.n << ","
436  << cindex.second.t;
437  if (cindex.second.x != 0)
438  os << "," << cindex.second.x;
439  os << ")";
440 }
441 
442 void PrintIndexes(std::ostream &os,
443  const std::vector<Index> &indexes) {
444  if (indexes.empty()) {
445  os << "[ ]";
446  return;
447  }
448  // If the string is longer than 'max_string_length' characters, it will
449  // be summarized with '...' in the middle.
450  size_t max_string_length = 200;
451  std::ostringstream os_temp;
452 
453  // range_starts will be the starts of ranges (with consecutive t values and
454  // the same n value and zero x values) that we compactly print. we'll append
455  // "end" to range_starts for convenience.n
456  std::vector<int32> range_starts;
457  int32 cur_start = 0, end = indexes.size();
458  for (int32 i = cur_start; i < end; i++) {
459  const Index &index = indexes[i];
460  if (i > cur_start &&
461  (index.t != indexes[i-1].t + 1 ||
462  index.n != indexes[i-1].n ||
463  index.x != indexes[i-1].x)) {
464  range_starts.push_back(cur_start);
465  cur_start = i;
466  }
467  }
468  range_starts.push_back(cur_start);
469  range_starts.push_back(end);
470  os_temp << "[";
471  int32 num_ranges = range_starts.size() - 1;
472  for (int32 r = 0; r < num_ranges; r++) {
473  int32 range_start = range_starts[r], range_end = range_starts[r+1];
474  KALDI_ASSERT(range_end > range_start);
475  os_temp << "(" << indexes[range_start].n << ",";
476  if (range_end == range_start + 1)
477  os_temp << indexes[range_start].t;
478  else
479  os_temp << indexes[range_start].t << ":" << indexes[range_end - 1].t;
480  if (indexes[range_start].x != 0)
481  os_temp << "," << indexes[range_start].x;
482  os_temp << ")";
483  if (r + 1 < num_ranges)
484  os_temp << ", ";
485  }
486  os_temp << "]";
487 
488  std::string str = os_temp.str();
489  if (str.size() <= max_string_length) {
490  os << str;
491  } else {
492  size_t len = str.size();
493  os << str.substr(0, max_string_length / 2) << " ... "
494  << str.substr(len - max_string_length / 2);
495  }
496 }
497 
498 void PrintCindexes(std::ostream &ostream,
499  const std::vector<Cindex> &cindexes,
500  const std::vector<std::string> &node_names) {
501  int32 num_cindexes = cindexes.size();
502  if (num_cindexes == 0) {
503  ostream << "[ ]";
504  return;
505  }
506  int32 cur_offset = 0;
507  std::vector<Index> indexes;
508  indexes.reserve(cindexes.size());
509  while (cur_offset < num_cindexes) {
510  int32 cur_node_index = cindexes[cur_offset].first;
511  while (cur_offset < num_cindexes &&
512  cindexes[cur_offset].first == cur_node_index) {
513  indexes.push_back(cindexes[cur_offset].second);
514  cur_offset++;
515  }
516  KALDI_ASSERT(static_cast<size_t>(cur_node_index) < node_names.size());
517  const std::string &node_name = node_names[cur_node_index];
518  ostream << node_name;
519  PrintIndexes(ostream, indexes);
520  indexes.clear();
521  }
522 }
523 
524 
525 void PrintIntegerVector(std::ostream &os,
526  const std::vector<int32> &ints) {
527  if (ints.empty()) {
528  os << "[ ]";
529  return;
530  }
531  // range_starts will be the starts of ranges (with consecutive or identical
532  // values) that we compactly print. we'll append "end" to range_starts for
533  // convenience.
534  std::vector<int32> range_starts;
535  int32 cur_start = 0, end = ints.size();
536  for (int32 i = cur_start; i < end; i++) {
537  if (i > cur_start) {
538  int32 range_start_val = ints[cur_start],
539  range_start_plus_one_val = ints[cur_start+1],
540  cur_val = ints[i];
541  // if we have reached the end of a range...
542  if (!((range_start_plus_one_val == range_start_val &&
543  cur_val == range_start_val) ||
544  (range_start_plus_one_val == range_start_val + 1 &&
545  cur_val == range_start_val + i - cur_start))) {
546  range_starts.push_back(cur_start);
547  cur_start = i;
548  }
549  }
550  }
551  range_starts.push_back(cur_start);
552  range_starts.push_back(end);
553  os << "[";
554  int32 num_ranges = range_starts.size() - 1;
555  for (int32 r = 0; r < num_ranges; r++) {
556  int32 range_start = range_starts[r], range_end = range_starts[r+1];
557  KALDI_ASSERT(range_end > range_start);
558  if (range_end == range_start + 1)
559  os << ints[range_start];
560  else if (range_end == range_start + 2) // don't print ranges of 2.
561  os << ints[range_start] << ", " << ints[range_start+1];
562  else if (ints[range_start] == ints[range_start+1])
563  os << ints[range_start] << "x" << (range_end - range_start);
564  else
565  os << ints[range_start] << ":" << ints[range_end - 1];
566  if (r + 1 < num_ranges)
567  os << ", ";
568  }
569  os << "]";
570 }
571 
572 // this will be the most negative number representable as int32.
573 const int kNoTime = std::numeric_limits<int32>::min();
574 
575 } // namespace nnet3
576 } // namespace kaldi
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void WriteIndexVector(std::ostream &os, bool binary, const std::vector< Index > &vec)
Definition: nnet-common.cc:126
size_t operator()(const std::vector< Cindex > &cindex_vector) const noexcept
Definition: nnet-common.cc:376
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
static void WriteIndexVectorElementBinary(std::ostream &os, const std::vector< Index > &vec, int32 i)
Definition: nnet-common.cc:45
void PrintCindex(std::ostream &os, const Cindex &cindex, const std::vector< std::string > &node_names)
Definition: nnet-common.cc:432
void ReadCindexVector(std::istream &is, bool binary, std::vector< Cindex > *vec)
Definition: nnet-common.cc:309
void PrintIndexes(std::ostream &os, const std::vector< Index > &indexes)
this will only be used for pretty-printing.
Definition: nnet-common.cc:442
kaldi::int32 int32
std::ostream & operator<<(std::ostream &ostream, const Index &index)
Definition: nnet-common.cc:424
static void ReadIndexVectorElementBinary(std::istream &is, int32 i, std::vector< Index > *vec)
Definition: nnet-common.cc:87
struct Index is intended to represent the various indexes by which we number the rows of the matrices...
Definition: nnet-common.h:44
std::pair< int32, Index > Cindex
Definition: nnet-common.h:115
void Read(std::istream &os, bool binary)
Definition: nnet-common.cc:37
void WriteCindexVector(std::ostream &os, bool binary, const std::vector< Cindex > &vec)
Definition: nnet-common.cc:282
static void ExpectToken(const std::string &token, const std::string &what_we_are_parsing, const std::string **next_token)
void Write(std::ostream &os, bool binary) const
Definition: nnet-common.cc:27
static void ReadCindexVectorElementBinary(std::istream &is, int32 i, std::vector< Cindex > *vec)
Definition: nnet-common.cc:224
size_t operator()(const Cindex &cindex) const noexcept
Definition: nnet-common.cc:367
#define KALDI_ERR
Definition: kaldi-error.h:147
void PrintIntegerVector(std::ostream &os, const std::vector< int32 > &ints)
Definition: nnet-common.cc:525
size_t operator()(const Index &cindex) const noexcept
Definition: nnet-common.cc:360
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
void ReadIndexVector(std::istream &is, bool binary, std::vector< Index > *vec)
Definition: nnet-common.cc:143
void PrintCindexes(std::ostream &ostream, const std::vector< Cindex > &cindexes, const std::vector< std::string > &node_names)
this will only be used for pretty-printing.
Definition: nnet-common.cc:498
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
static void WriteCindexVectorElementBinary(std::ostream &os, const std::vector< Cindex > &vec, int32 i)
Definition: nnet-common.cc:162
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34
size_t operator()(const std::vector< Index > &index_vector) const noexcept
Definition: nnet-common.cc:388
const int kNoTime
Definition: nnet-common.cc:573