nnet-component.h
Go to the documentation of this file.
1 // nnet/nnet-component.h
2 
3 // Copyright 2011-2016 Brno University of Technology (Author: Karel Vesely)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
21 
22 #ifndef KALDI_NNET_NNET_COMPONENT_H_
23 #define KALDI_NNET_NNET_COMPONENT_H_
24 
25 #include <iostream>
26 #include <string>
27 
28 #include "base/kaldi-common.h"
29 #include "matrix/matrix-lib.h"
30 #include "cudamatrix/cu-matrix.h"
31 #include "cudamatrix/cu-vector.h"
32 #include "nnet/nnet-trnopts.h"
33 
34 namespace kaldi {
35 namespace nnet1 {
36 
// Abstract class, building block of the network: every component carries an
// input and an output dimension (stored by the constructor) and declares the
// pure-virtual PropagateFnc()/BackpropagateFnc() hooks that concrete
// components implement.
// NOTE(review): this text is a rendered listing; the leading numbers are the
// listing's own line numbers, and gaps in them mark lines that are not shown
// (doc comments, most enumerators, key_value::key, the Propagate()
// declaration, the data members, the head of NewComponentOfType()).
43 class Component {
45  public:
    // Component type identification mechanism. Only a few enumerators are
    // visible here; the numbering gaps hide the others.
47  typedef enum {
48  kUnknown = 0x0,
49 
57 
67 
    // NOTE(review): spelled 'kTranform' (sic); it is a public enumerator, so
    // the spelling is left untouched.
68  kTranform = 0x0400,
76 
77  kKlHmm = 0x0800,
    // No explicit value, so this enumerator takes kKlHmm + 1.
78  kSentenceAveragingComponent, /* deprecated */
85  } ComponentType;
86 
    // A pair of type and marker. NOTE(review): the 'key' member (listing line
    // 89) is not shown; only the marker string 'value' is visible.
88  struct key_value {
90  const char *value;
91  };
92 
    // Array of key_value pairs, declared here and defined elsewhere.
    // NOTE(review): presumably the table behind TypeToMarker()/MarkerToType()
    // -- confirm in the implementation file.
95  static const struct key_value kMarkerMap[];
96 
    // Converts component type to marker.
98  static const char* TypeToMarker(ComponentType t);
99 
     // Converts marker to component type (case insensitive).
101  static ComponentType MarkerToType(const std::string &s);
102 
104  public:
     // Stores the input and output dimensions of the component.
105  Component(int32 input_dim, int32 output_dim):
106  input_dim_(input_dim),
107  output_dim_(output_dim)
108  { }
109 
     // Virtual destructor with an empty body.
110  virtual ~Component()
111  { }
112 
     // Copy component (deep copy); pure virtual.
114  virtual Component* Copy() const = 0;
115 
     // Get type identification of the component; pure virtual.
117  virtual ComponentType GetType() const = 0;
118 
     // Check if component has 'Updatable' interface (trainable components).
     // Returns false here; UpdatableComponent overrides it to return true.
120  virtual bool IsUpdatable() const {
121  return false;
122  }
123 
     // Check if component has 'Recurrent' interface (trainable and recurrent).
     // Returns false here; MultistreamComponent overrides it to return true.
125  virtual bool IsMultistream() const {
126  return false;
127  }
128 
     // Get the dimension of the input.
130  int32 InputDim() const {
131  return input_dim_;
132  }
133 
     // Get the dimension of the output.
135  int32 OutputDim() const {
136  return output_dim_;
137  }
138 
     // NOTE(review): the Propagate() declaration (forward pass 'in' -> 'out')
     // is not shown in this listing; its inline body appears later in the file.
141 
     // Perform backward-pass propagation 'out_diff' -> 'in_diff'. Non-virtual:
     // the inline body later in this file checks the dimensions, resizes
     // *in_diff and then calls BackpropagateFnc().
144  void Backpropagate(const CuMatrixBase<BaseFloat> &in,
145  const CuMatrixBase<BaseFloat> &out,
146  const CuMatrixBase<BaseFloat> &out_diff,
147  CuMatrix<BaseFloat> *in_diff);
148 
     // Initialize component from a line in config file (static method).
150  static Component* Init(const std::string &conf_line);
151 
     // Read the component from a stream (static method).
153  static Component* Read(std::istream &is, bool binary);
154 
     // Write the component to a stream.
156  void Write(std::ostream &os, bool binary) const;
157 
     // Print some additional info (after <ComponentName> and the dims);
     // the default returns an empty string.
159  virtual std::string Info() const { return ""; }
160 
     // Print some additional info about gradient (after <...> and dims);
     // the default returns an empty string.
162  virtual std::string InfoGradient() const { return ""; }
163 
164 
     // Abstract interface for propagation/backpropagation.
166  protected:
     // Forward pass hook called by Propagate(); pure virtual.
168  virtual void PropagateFnc(const CuMatrixBase<BaseFloat> &in,
169  CuMatrixBase<BaseFloat> *out) = 0;
170 
     // Backward pass transformation (to be implemented by descending
     // classes); called by Backpropagate().
172  virtual void BackpropagateFnc(const CuMatrixBase<BaseFloat> &in,
173  const CuMatrixBase<BaseFloat> &out,
174  const CuMatrixBase<BaseFloat> &out_diff,
175  CuMatrixBase<BaseFloat> *in_diff) = 0;
176 
     // Virtual interface for initialization and I/O; each default below has
     // an empty body.
178  protected:
     // Initialization hook taking an input stream.
180  virtual void InitData(std::istream &is) { }
181 
     // Reads the component content.
183  virtual void ReadData(std::istream &is, bool binary) { }
184 
     // Writes the component content.
186  virtual void WriteData(std::ostream &os, bool binary) const { }
187 
     // Data members. NOTE(review): the declarations of input_dim_ and
     // output_dim_ (used by the constructor and accessors above) are not
     // shown in this listing.
189  protected:
192 
     // Private members (descending classes cannot call this).
     // NOTE(review): the two lines below are the tail of the
     // NewComponentOfType() declaration; its first line is not shown.
194  private:
197  ComponentType t, int32 input_dim, int32 output_dim
198  );
199 };
200 
201 
// Body of class UpdatableComponent: a Component which has trainable
// parameters, together with its training options and learn-rate coefficients.
// NOTE(review): the class-head line, the destructor head, the head of
// GetTrainOptions() and the data-member declarations are not shown in this
// listing (see the gaps in the leading listing numbers).
209  public:
     // Forwards the dimensions to Component and initializes both learn-rate
     // coefficients to 1.0.
210  UpdatableComponent(int32 input_dim, int32 output_dim):
211  Component(input_dim, output_dim),
212  learn_rate_coef_(1.0),
213  bias_learn_rate_coef_(1.0)
214  { }
215 
     // NOTE(review): empty body whose declarator line is missing from the
     // listing -- presumably the virtual destructor; confirm in the header.
217  { }
218 
     // Check if contains trainable parameters; always true for this class.
220  bool IsUpdatable() const {
221  return true;
222  }
223 
     // Pure virtual. NOTE(review): presumably the number of trainable
     // parameters -- confirm against a derived class.
225  virtual int32 NumParams() const = 0;
226 
     // Pure virtual; writes into the caller-supplied vector '*gradient'.
     // NOTE(review): layout of the vector is defined by derived classes.
228  virtual void GetGradient(VectorBase<BaseFloat> *gradient) const = 0;
229 
     // Pure virtual; writes into the caller-supplied vector '*params'.
231  virtual void GetParams(VectorBase<BaseFloat> *params) const = 0;
232 
     // Pure virtual; takes the parameters from the vector 'params'.
234  virtual void SetParams(const VectorBase<BaseFloat> &params) = 0;
235 
     // Pure virtual. NOTE(review): presumably computes the parameter update
     // from the component input and the back-propagated 'diff' -- semantics
     // are defined by derived classes.
237  virtual void Update(const CuMatrixBase<BaseFloat> &input,
238  const CuMatrixBase<BaseFloat> &diff) = 0;
239 
     // Set the training options to the component (copied into opts_).
241  virtual void SetTrainOptions(const NnetTrainOptions &opts) {
242  opts_ = opts;
243  }
244 
     // Get the training options from the component.
     // NOTE(review): the declarator line of GetTrainOptions() is not shown;
     // only the tail of its body is visible below.
247  return opts_;
248  }
249 
     // Set the learn-rate coefficient.
251  virtual void SetLearnRateCoef(BaseFloat val) {
252  learn_rate_coef_ = val;
253  }
254 
     // Set the learn-rate coefficient for bias.
256  virtual void SetBiasLearnRateCoef(BaseFloat val) {
257  bias_learn_rate_coef_ = val;
258  }
259 
     // Re-declares Component::InitData() as pure virtual, so a concrete
     // updatable component has to provide its own initialization.
261  virtual void InitData(std::istream &is) = 0;
262 
     // Data members. NOTE(review): the declarations of opts_,
     // learn_rate_coef_ and bias_learn_rate_coef_ (all used above) are not
     // shown in this listing.
263  protected:
266 
270 
274 };
275 
276 
// Body of class MultistreamComponent: an extension of UpdatableComponent for
// recurrent networks, which are trained with parallel sequences.
// NOTE(review): the class-head line is not shown in this listing.
282  public:
     // Forwards the dimensions to UpdatableComponent.
283  MultistreamComponent(int32 input_dim, int32 output_dim):
284  UpdatableComponent(input_dim, output_dim)
285  { }
286 
     // Always true for this class (overrides the 'false' default of Component).
287  bool IsMultistream() const {
288  return true;
289  }
290 
     // Stores a copy of the per-sequence lengths in sequence_lengths_.
291  virtual void SetSeqLengths(const std::vector<int32>& sequence_lengths) {
292  sequence_lengths_ = sequence_lengths;
293  }
294 
     // Number of entries in sequence_lengths_, but never less than 1
     // (so 1 is returned while no lengths have been set).
295  int32 NumStreams() const {
296  return std::max<int32>(1, sequence_lengths_.size());
297  }
298 
     // Optional function to reset the transfer of context; the default
     // implementation ignores 'stream_reset_flag' and does nothing.
300  virtual void ResetStreams(const std::vector<int32>& stream_reset_flag)
301  { }
302 
303  protected:
     // Lengths set via SetSeqLengths(); its size drives NumStreams().
304  std::vector<int32> sequence_lengths_;
305 };
306 
307 
308 /*
309  * Inline methods for ::Component,
310  */
312  CuMatrix<BaseFloat> *out) {
     // Forward-pass wrapper 'in' -> 'out': checks the input width, resizes
     // '*out' to (rows of 'in') x output_dim_, then delegates to the
     // component-specific PropagateFnc().
     // NOTE(review): the first signature line of this inline definition
     // (with the 'in' parameter) is not shown in this listing.
313  // Check the dims
     // A column-count mismatch with input_dim_ is reported through KALDI_ERR.
314  if (input_dim_ != in.NumCols()) {
315  KALDI_ERR << "Non-matching dims on the input of " << TypeToMarker(GetType())
316  << " component. The input-dim is " << input_dim_
317  << ", the data had " << in.NumCols() << " dims.";
318  }
319  // Allocate target buffer
     // Resize is called with kSetZero on every call.
320  out->Resize(in.NumRows(), output_dim_, kSetZero); // reset
321  // Call the propagation implementation of the component
322  PropagateFnc(in, out);
323 }
324 
326  const CuMatrixBase<BaseFloat> &out,
327  const CuMatrixBase<BaseFloat> &out_diff,
328  CuMatrix<BaseFloat> *in_diff) {
     // Backward-pass wrapper 'out_diff' -> 'in_diff': validates the shapes of
     // all three inputs, resizes '*in_diff' to num_frames x InputDim(), then
     // delegates to the component-specific BackpropagateFnc().
     // NOTE(review): the first signature line of this inline definition
     // (with the 'in' parameter) is not shown in this listing.
329  // Check the dims,
     // A width mismatch of 'out_diff' is reported through KALDI_ERR.
330  if (OutputDim() != out_diff.NumCols()) {
331  KALDI_ERR << "Non-matching dims! Component output dim " << OutputDim()
332  << ", the dim of output derivatives " << out_diff.NumCols();
333  }
334 
     // 'in', 'out' and 'out_diff' must all have the same number of rows.
335  int32 num_frames = out_diff.NumRows();
336  KALDI_ASSERT(num_frames == in.NumRows());
337  KALDI_ASSERT(num_frames == out.NumRows());
338 
     // Column counts of 'in' and 'out' must match the component dimensions.
339  KALDI_ASSERT(InputDim() == in.NumCols());
340  KALDI_ASSERT(OutputDim() == out.NumCols());
341 
342  // Allocate target buffer,
     // The null check comes before the first dereference of in_diff.
343  KALDI_ASSERT(in_diff != NULL);
344  in_diff->Resize(num_frames, InputDim(), kSetZero); // reset,
345 
346  // Call the 'virtual' backprop function,
347  BackpropagateFnc(in, out, out_diff, in_diff);
348 }
349 
350 
351 } // namespace nnet1
352 } // namespace kaldi
353 
354 
355 #endif // KALDI_NNET_NNET_COMPONENT_H_
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
virtual void SetTrainOptions(const NnetTrainOptions &opts)
Set the training options to the component,.
virtual std::string Info() const
Print some additional info (after <ComponentName> and the dims),.
NnetTrainOptions opts_
Option-class with training hyper-parameters,.
int32 input_dim_
Data members,.
BaseFloat bias_learn_rate_coef_
Scalar applied to learning rate for bias (to be used in ::Update method),.
BaseFloat learn_rate_coef_
Scalar applied to learning rate for weight matrices (to be used in ::Update method),.
bool IsUpdatable() const
Check if contains trainable parameters,.
Class UpdatableComponent is a Component which has trainable parameters, it contains SGD training hype...
Component(int32 input_dim, int32 output_dim)
Generic interface of a component,.
void Backpropagate(const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
Perform backward-pass propagation 'out_diff' -> 'in_diff'.
kaldi::int32 int32
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
Definition: matrix-common.h:71
static Component * Init(const std::string &conf_line)
Initialize component from a line in config file,.
static Component * Read(std::istream &is, bool binary)
Read the component from a stream (static method),.
virtual void SetLearnRateCoef(BaseFloat val)
Set the learn-rate coefficient,.
ComponentType
Component type identification mechanism,.
virtual void ReadData(std::istream &is, bool binary)
Reads the component content.
virtual bool IsUpdatable() const
Check if component has 'Updatable' interface (trainable components).
MultistreamComponent(int32 input_dim, int32 output_dim)
A pair of type and marker,.
virtual void SetSeqLengths(const std::vector< int32 > &sequence_lengths)
static const char * TypeToMarker(ComponentType t)
Converts component type to marker,.
const Component::ComponentType key
virtual void ResetStreams(const std::vector< int32 > &stream_reset_flag)
Optional function to reset the transfer of context (not used for BLSTMs).
virtual void SetBiasLearnRateCoef(BaseFloat val)
Set the learn-rate coefficient for bias,.
void Write(std::ostream &os, bool binary) const
Write the component to a stream,.
void Propagate(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward-pass propagation 'in' -> 'out'.
int32 InputDim() const
Get the dimension of the input,.
static ComponentType MarkerToType(const std::string &s)
Converts marker to component type (case insensitive),.
UpdatableComponent(int32 input_dim, int32 output_dim)
virtual Component * Copy() const =0
Copy component (deep copy),.
virtual void WriteData(std::ostream &os, bool binary) const
Writes the component content.
#define KALDI_ERR
Definition: kaldi-error.h:147
const NnetTrainOptions & GetTrainOptions() const
Get the training options from the component,.
virtual bool IsMultistream() const
Check if component has 'Recurrent' interface (trainable and recurrent).
virtual void InitData(std::istream &is)
Virtual interface for initialization and I/O,.
virtual std::string InfoGradient() const
Print some additional info about gradient (after <...> and dims),.
Class MultistreamComponent is an extension of UpdatableComponent for recurrent networks, which are trained with parallel sequences.
int32 output_dim_
Dimension of the output of the Component,.
virtual void PropagateFnc(const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out)=0
Abstract interface for propagation/backpropagation.
Matrix for CUDA computing.
Definition: matrix-common.h:69
MatrixIndexT NumCols() const
Definition: cu-matrix.h:216
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
virtual ComponentType GetType() const =0
Get Type Identification of the component,.
Abstract class, building block of the network.
std::vector< int32 > sequence_lengths_
int32 OutputDim() const
Get the dimension of the output,.
MatrixIndexT NumRows() const
Dimensions.
Definition: cu-matrix.h:215
virtual void BackpropagateFnc(const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrixBase< BaseFloat > *in_diff)=0
Backward pass transformation (to be implemented by descending class...)
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
void Resize(MatrixIndexT rows, MatrixIndexT cols, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Allocate the memory.
Definition: cu-matrix.cc:50
static Component * NewComponentOfType(ComponentType t, int32 input_dim, int32 output_dim)
Private members (descending classes cannot call this),.
bool IsMultistream() const
Check if component has 'Recurrent' interface (trainable and recurrent).