nnet-component.cc
Go to the documentation of this file.
1 // nnet/nnet-component.cc
2 
3 // Copyright 2011-2013 Brno University of Technology (Author: Karel Vesely)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
21 #include <algorithm>
22 #include <sstream>
23 
24 #include "nnet/nnet-component.h"
25 
26 #include "nnet/nnet-nnet.h"
27 #include "nnet/nnet-activation.h"
28 #include "nnet/nnet-kl-hmm.h"
31 #include "nnet/nnet-rbm.h"
32 #include "nnet/nnet-various.h"
33 
37 
40 #include "nnet/nnet-recurrent.h"
41 
47 
48 namespace kaldi {
49 namespace nnet1 {
50 
52  { Component::kAffineTransform, "<AffineTransform>" },
53  { Component::kLinearTransform, "<LinearTransform>" },
54  { Component::kConvolutionalComponent, "<ConvolutionalComponent>" },
55  { Component::kLstmProjected, "<LstmProjected>" },
56  { Component::kLstmProjected, "<LstmProjectedStreams>" }, // bwd compat.
57  { Component::kBlstmProjected, "<BlstmProjected>" },
58  { Component::kBlstmProjected, "<BlstmProjectedStreams>" }, // bwd compat.
59  { Component::kRecurrentComponent, "<RecurrentComponent>" },
60  { Component::kSoftmax, "<Softmax>" },
61  { Component::kHiddenSoftmax, "<HiddenSoftmax>" },
62  { Component::kBlockSoftmax, "<BlockSoftmax>" },
63  { Component::kSigmoid, "<Sigmoid>" },
64  { Component::kTanh, "<Tanh>" },
65  { Component::kParametricRelu,"<ParametricRelu>" },
66  { Component::kDropout, "<Dropout>" },
67  { Component::kLengthNormComponent, "<LengthNormComponent>" },
68  { Component::kRbm, "<Rbm>" },
69  { Component::kSplice, "<Splice>" },
70  { Component::kCopy, "<Copy>" },
71  { Component::kAddShift, "<AddShift>" },
72  { Component::kRescale, "<Rescale>" },
73  { Component::kKlHmm, "<KlHmm>" },
74  { Component::kAveragePoolingComponent, "<AveragePoolingComponent>" },
75  { Component::kMaxPoolingComponent, "<MaxPoolingComponent>" },
76  { Component::kSentenceAveragingComponent, "<SentenceAveragingComponent>" },
77  { Component::kSimpleSentenceAveragingComponent, "<SimpleSentenceAveragingComponent>" },
78  { Component::kFramePoolingComponent, "<FramePoolingComponent>" },
79  { Component::kParallelComponent, "<ParallelComponent>" },
80  { Component::kMultiBasisComponent, "<MultiBasisComponent>" },
81 };
82 
83 
85  // Returns the 1st '<string>' corresponding to the type in 'kMarkerMap',
86  int32 N = sizeof(kMarkerMap) / sizeof(kMarkerMap[0]);
87  for (int i = 0; i < N; i++) {
88  if (kMarkerMap[i].key == t) return kMarkerMap[i].value;
89  }
90  KALDI_ERR << "Unknown type : " << t;
91  return NULL;
92 }
93 
95  std::string s_lowercase(s);
96  std::transform(s.begin(), s.end(), s_lowercase.begin(), ::tolower); // lc
97  int32 N = sizeof(kMarkerMap) / sizeof(kMarkerMap[0]);
98  for (int i = 0; i < N; i++) {
99  std::string m(kMarkerMap[i].value);
100  std::string m_lowercase(m);
101  std::transform(m.begin(), m.end(), m_lowercase.begin(), ::tolower);
102  if (s_lowercase == m_lowercase) return kMarkerMap[i].key;
103  }
104  KALDI_ERR << "Unknown 'Component' marker : '" << s << "'\n"
105  << "(isn't the model 'too old' or incompatible?)";
106  return kUnknown;
107 }
108 
109 
111  int32 input_dim, int32 output_dim) {
112  Component *ans = NULL;
113  switch (comp_type) {
115  ans = new AffineTransform(input_dim, output_dim);
116  break;
118  ans = new LinearTransform(input_dim, output_dim);
119  break;
121  ans = new ConvolutionalComponent(input_dim, output_dim);
122  break;
124  ans = new LstmProjected(input_dim, output_dim);
125  break;
127  ans = new BlstmProjected(input_dim, output_dim);
128  break;
130  ans = new RecurrentComponent(input_dim, output_dim);
131  break;
132  case Component::kSoftmax :
133  ans = new Softmax(input_dim, output_dim);
134  break;
136  ans = new HiddenSoftmax(input_dim, output_dim);
137  break;
139  ans = new BlockSoftmax(input_dim, output_dim);
140  break;
141  case Component::kSigmoid :
142  ans = new Sigmoid(input_dim, output_dim);
143  break;
144  case Component::kTanh :
145  ans = new Tanh(input_dim, output_dim);
146  break;
148  ans = new ParametricRelu(input_dim, output_dim);
149  break;
150  case Component::kDropout :
151  ans = new Dropout(input_dim, output_dim);
152  break;
154  ans = new LengthNormComponent(input_dim, output_dim);
155  break;
156  case Component::kRbm :
157  ans = new Rbm(input_dim, output_dim);
158  break;
159  case Component::kSplice :
160  ans = new Splice(input_dim, output_dim);
161  break;
162  case Component::kCopy :
163  ans = new CopyComponent(input_dim, output_dim);
164  break;
165  case Component::kAddShift :
166  ans = new AddShift(input_dim, output_dim);
167  break;
168  case Component::kRescale :
169  ans = new Rescale(input_dim, output_dim);
170  break;
171  case Component::kKlHmm :
172  ans = new KlHmm(input_dim, output_dim);
173  break;
175  ans = new SentenceAveragingComponent(input_dim, output_dim);
176  break;
178  ans = new SimpleSentenceAveragingComponent(input_dim, output_dim);
179  break;
181  ans = new AveragePoolingComponent(input_dim, output_dim);
182  break;
184  ans = new MaxPoolingComponent(input_dim, output_dim);
185  break;
187  ans = new FramePoolingComponent(input_dim, output_dim);
188  break;
190  ans = new ParallelComponent(input_dim, output_dim);
191  break;
193  ans = new MultiBasisComponent(input_dim, output_dim);
194  break;
195  case Component::kUnknown :
196  default :
197  KALDI_ERR << "Missing type: " << TypeToMarker(comp_type);
198  }
199  return ans;
200 }
201 
202 
203 Component* Component::Init(const std::string &conf_line) {
204  std::istringstream is(conf_line);
205  std::string component_type_string;
206  int32 input_dim, output_dim;
207 
208  // initialize component w/o internal data
209  ReadToken(is, false, &component_type_string);
210  ComponentType component_type = MarkerToType(component_type_string);
211  ExpectToken(is, false, "<InputDim>");
212  ReadBasicType(is, false, &input_dim);
213  ExpectToken(is, false, "<OutputDim>");
214  ReadBasicType(is, false, &output_dim);
215  Component *ans = NewComponentOfType(component_type, input_dim, output_dim);
216 
217  // initialize internal data with the remaining part of config line
218  ans->InitData(is);
219 
220  return ans;
221 }
222 
223 
224 Component* Component::Read(std::istream &is, bool binary) {
225  int32 dim_out, dim_in;
226  std::string token;
227 
228  int first_char = Peek(is, binary);
229  if (first_char == EOF) return NULL;
230 
231  ReadToken(is, binary, &token);
232  // Skip the optional initial token,
233  if (token == "<Nnet>") {
234  ReadToken(is, binary, &token);
235  }
236  // Network ends after terminal token appears,
237  if (token == "</Nnet>") {
238  return NULL;
239  }
240 
241  // Read the dims,
242  ReadBasicType(is, binary, &dim_out);
243  ReadBasicType(is, binary, &dim_in);
244 
245  // Create the component,
246  Component *ans = NewComponentOfType(MarkerToType(token), dim_in, dim_out);
247 
248  // Read the content,
249  ans->ReadData(is, binary);
250 
251  // 'Eat' the component separtor (can be already consumed by 'ReadData(.)'),
252  if ('<' == Peek(is, binary) && '!' == PeekToken(is, binary)) {
253  ExpectToken(is, binary, "<!EndOfComponent>");
254  }
255 
256  return ans;
257 }
258 
259 
// Writes the component to the stream 'os', in binary or text mode
// according to 'binary'.
//
// As visible here, the output consists of: the output dim, the input dim,
// the component content produced by 'WriteData(.)', and the separator
// '<!EndOfComponent>'. In text mode a newline follows the dims and the
// separator.
//
// NOTE(review): the embedded listing numbers jump from 260 to 262, so one
// statement of the body seems to be absent from this listing. Presumably it
// is the write of the component marker, which 'Read(.)' expects in front of
// the dims; confirm against the original file.
260 void Component::Write(std::ostream &os, bool binary) const {
// The dims go out in the same order as 'Read(.)' reads them: output, input,
262  WriteBasicType(os, binary, OutputDim());
263  WriteBasicType(os, binary, InputDim());
264  if (!binary) os << "\n";
// The component-specific content,
265  this->WriteData(os, binary);
266  WriteToken(os, binary, "<!EndOfComponent>"); // Write component separator.
267  if (!binary) os << "\n";
268 }
269 
270 
271 } // namespace nnet1
272 } // namespace kaldi
Deprecated!!!, keeping it as Katka Zmolikova used it in JSALT 2015.
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
ConvolutionalComponent implements convolution over single axis (i.e.
MaxPoolingComponent : The input/output matrices are split to submatrices with width 'pool_stride_'...
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
Rearrange the matrix columns according to the indices in copy_from_indices_.
Definition: nnet-various.h:146
kaldi::int32 int32
void ReadToken(std::istream &is, bool binary, std::string *str)
ReadToken gets the next token and puts it in str (exception on failure).
Definition: io-funcs.cc:154
static Component * Init(const std::string &conf_line)
Initialize component from a line in config file,.
static Component * Read(std::istream &is, bool binary)
Read the component from a stream (static method),.
int Peek(std::istream &is, bool binary)
Peek consumes whitespace (if binary == false) and then returns the peek() value of the stream...
Definition: io-funcs.cc:145
ComponentType
Component type identification mechanism,.
Rescale the data column-wise by a vector (can be used for global variance normalization) ...
Definition: nnet-various.h:404
virtual void ReadData(std::istream &is, bool binary)
Reads the component content.
A pair of type and marker,.
static const struct key_value kMarkerMap[]
The table with pairs of Component types and markers (defined in nnet-component.cc),.
Rescale the matrix-rows to have unit length (L2-norm).
Definition: nnet-various.h:244
static const char * TypeToMarker(ComponentType t)
Converts component type to marker,.
const Component::ComponentType key
Adds shift to all the lines of the matrix (can be used for global mean normalization) ...
Definition: nnet-various.h:291
void Write(std::ostream &os, bool binary) const
Write the component to a stream,.
void ExpectToken(std::istream &is, bool binary, const char *token)
ExpectToken tries to read in the given token, and throws an exception on failure. ...
Definition: io-funcs.cc:191
int32 InputDim() const
Get the dimension of the input,.
static ComponentType MarkerToType(const std::string &s)
Converts marker to component type (case insensitive),.
virtual void WriteData(std::ostream &os, bool binary) const
Writes the component content.
#define KALDI_ERR
Definition: kaldi-error.h:147
virtual void InitData(std::istream &is)
Virtual interface for initialization and I/O,.
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
int PeekToken(std::istream &is, bool binary)
PeekToken will return the first character of the next token, or -1 if end of file.
Definition: io-funcs.cc:170
void Splice(const CuMatrixBase< Real > &src, const CuArray< int32 > &frame_offsets, CuMatrixBase< Real > *tgt)
Splice concatenates frames of src as specified in frame_offsets into tgt.
Definition: cu-math.cc:132
Component with recurrent connections, 'tanh' non-linearity.
virtual ComponentType GetType() const =0
Get Type Identification of the component,.
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34
Abstract class, building block of the network.
AveragePoolingComponent : The input/output matrices are split to submatrices with width 'pool_stride_...
int32 OutputDim() const
Get the dimension of the output,.
SimpleSentenceAveragingComponent does not have nested network, it is intended to be used inside of a ...
FramePoolingComponent : The input/output matrices are split to frames of width 'feature_dim_'.
static Component * NewComponentOfType(ComponentType t, int32 input_dim, int32 output_dim)
Private members (descending classes cannot call this),.