nnet-common.h
Go to the documentation of this file.
1 // nnet3/nnet-common.h
2 
3 // Copyright 2015 Johns Hopkins University (author: Daniel Povey)
4 // 2016 Xiaohui Zhang
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #ifndef KALDI_NNET3_NNET_COMMON_H_
22 #define KALDI_NNET3_NNET_COMMON_H_
23 
24 #include "base/kaldi-common.h"
25 #include "util/common-utils.h"
26 #include "itf/options-itf.h"
27 #include "matrix/matrix-lib.h"
29 
30 #include <iostream>
31 
32 namespace kaldi {
33 namespace nnet3 {
34 
35 
44 struct Index {
45  int32 n; // member-index of minibatch, or zero.
46  int32 t; // time-frame.
47  int32 x; // this may come in useful in convoluational approaches.
48  // ... it is possible to add extra index here, if needed.
49  Index(): n(0), t(0), x(0) { }
50  Index(int32 n, int32 t, int32 x = 0): n(n), t(t), x(x) { }
51 
52  bool operator == (const Index &a) const {
53  return n == a.n && t == a.t && x == a.x;
54  }
55  bool operator != (const Index &a) const {
56  return n != a.n || t != a.t || x != a.x;
57  }
58  bool operator < (const Index &a) const {
59  if (t < a.t) { return true; }
60  else if (t > a.t) { return false; }
61  else if (x < a.x) { return true; }
62  else if (x > a.x) { return false; }
63  else return (n < a.n);
64  }
65  Index operator + (const Index &other) const {
66  return Index(n+other.n, t+other.t, x+other.x);
67  }
68  Index &operator += (const Index &other) {
69  n += other.n;
70  t += other.t;
71  x += other.x;
72  return *this;
73  }
74 
75  void Write(std::ostream &os, bool binary) const;
76 
77  void Read(std::istream &os, bool binary);
78 };
79 
80 
// This will be the most negative number representable as int32 (the value is
// set at its definition in nnet-common.cc).  It is used as the 't' value when
// we need to mark an 'invalid' index.  This can happen with certain
// non-simple components whose ReorderIndexes() functions need to insert
// spaces (placeholder positions) into their inputs or outputs.
extern const int kNoTime;
86 
87 // This struct can be used as a comparison object when you want to
88 // sort the indexes first on n, then x, then t (Index's own comparison
89 // object will sort first on t, then n, then x)
90 struct IndexLessNxt {
91  inline bool operator ()(const Index &a, const Index &b) const {
92  if (a.n < b.n) { return true; }
93  else if (a.n > b.n) { return false; }
94  else if (a.x < b.x) { return true; }
95  else if (a.x > b.x) { return false; }
96  else return (a.t < b.t);
97  }
98 };
99 
100 
101 // this will be used only for debugging output.
102 std::ostream &operator << (std::ostream &ostream, const Index &index);
103 
104 
105 void WriteIndexVector(std::ostream &os, bool binary,
106  const std::vector<Index> &vec);
107 
108 void ReadIndexVector(std::istream &is, bool binary,
109  std::vector<Index> *vec);
110 
111 
112 /* A Cindex is a pair of a node-index (i.e. the index of a NetworkNode) and an
113  Index. It's frequently used so it gets its own typedef.
114  */
115 typedef std::pair<int32, Index> Cindex;
116 
117 struct IndexHasher {
118  size_t operator () (const Index &cindex) const noexcept;
119 };
120 
121 struct CindexHasher {
122  size_t operator () (const Cindex &cindex) const noexcept;
123 };
124 
126  size_t operator () (const std::vector<Cindex> &cindex_vector) const noexcept;
127 };
128 
129 // Note: because IndexVectorHasher is used in some things where we really need
130 // it to be fast, it doesn't look at all the indexes, just most of them.
132  size_t operator () (const std::vector<Index> &index_vector) const noexcept;
133 };
134 
135 
136 
137 // this will only be used for pretty-printing.
138 void PrintCindex(std::ostream &ostream, const Cindex &cindex,
139  const std::vector<std::string> &node_names);
140 
148 void PrintIndexes(std::ostream &ostream,
149  const std::vector<Index> &indexes);
150 
158 void PrintCindexes(std::ostream &ostream,
159  const std::vector<Cindex> &cindexes,
160  const std::vector<std::string> &node_names);
161 
163 void AppendCindexes(int32 node, const std::vector<Index> &indexes,
164  std::vector<Cindex> *out);
165 
166 void WriteCindexVector(std::ostream &os, bool binary,
167  const std::vector<Cindex> &vec);
168 
169 void ReadCindexVector(std::istream &is, bool binary,
170  std::vector<Cindex> *vec);
171 
172 // this function prints a vector of integers in a human-readable
173 // way, for pretty-printing; it outputs ranges and repeats in
174 // a compact form e.g. [ -1x10, 1:20, 25:40 ]
175 void PrintIntegerVector(std::ostream &ostream,
176  const std::vector<int32> &ints);
177 
178 
179 // this will be used only for debugging output.
180 std::ostream &operator << (std::ostream &ostream, const Cindex &cindex);
181 
182 
// Forward declarations of types that are defined in other nnet3 headers.
class Component;
class Nnet;
struct MiscComputationInfo;
187 
188 } // namespace nnet3
189 } // namespace kaldi
190 
191 #endif
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void WriteIndexVector(std::ostream &os, bool binary, const std::vector< Index > &vec)
Definition: nnet-common.cc:126
bool operator<(const Index &a) const
Definition: nnet-common.h:58
Abstract base-class for neural-net components.
void PrintCindex(std::ostream &os, const Cindex &cindex, const std::vector< std::string > &node_names)
Definition: nnet-common.cc:432
void ReadCindexVector(std::istream &is, bool binary, std::vector< Cindex > *vec)
Definition: nnet-common.cc:309
void PrintIndexes(std::ostream &os, const std::vector< Index > &indexes)
this will only be used for pretty-printing.
Definition: nnet-common.cc:442
kaldi::int32 int32
std::ostream & operator<<(std::ostream &ostream, const Index &index)
Definition: nnet-common.cc:424
struct Index is intended to represent the various indexes by which we number the rows of the matrices...
Definition: nnet-common.h:44
std::pair< int32, Index > Cindex
Definition: nnet-common.h:115
void Read(std::istream &os, bool binary)
Definition: nnet-common.cc:37
void WriteCindexVector(std::ostream &os, bool binary, const std::vector< Cindex > &vec)
Definition: nnet-common.cc:282
bool operator==(const Index &a) const
Definition: nnet-common.h:52
bool operator!=(const Index &a) const
Definition: nnet-common.h:55
void Write(std::ostream &os, bool binary) const
Definition: nnet-common.cc:27
Index(int32 n, int32 t, int32 x=0)
Definition: nnet-common.h:50
void PrintIntegerVector(std::ostream &os, const std::vector< int32 > &ints)
Definition: nnet-common.cc:525
void AppendCindexes(int32 node, const std::vector< Index > &indexes, std::vector< Cindex > *out)
Appends to 'out' the pairs (node, indexes[0]), (node, indexes[1]), ...
void ReadIndexVector(std::istream &is, bool binary, std::vector< Index > *vec)
Definition: nnet-common.cc:143
void PrintCindexes(std::ostream &ostream, const std::vector< Cindex > &cindexes, const std::vector< std::string > &node_names)
this will only be used for pretty-printing.
Definition: nnet-common.cc:498
Index & operator+=(const Index &other)
Definition: nnet-common.h:68
Index operator+(const Index &other) const
Definition: nnet-common.h:65
const int kNoTime
Definition: nnet-common.cc:573