nnet-max-pooling-component.h
Go to the documentation of this file.
1 // nnet/nnet-max-pooling-component.h
2 
3 // Copyright 2014 Brno University of Technology (author: Karel Vesely)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
21 #ifndef KALDI_NNET_NNET_MAX_POOLING_COMPONENT_H_
22 #define KALDI_NNET_NNET_MAX_POOLING_COMPONENT_H_
23 
24 #include <string>
25 #include <vector>
26 
27 #include "nnet/nnet-component.h"
28 #include "nnet/nnet-utils.h"
29 #include "cudamatrix/cu-math.h"
30 
31 namespace kaldi {
32 namespace nnet1 {
33 
41  public:
42  MaxPoolingComponent(int32 dim_in, int32 dim_out):
43  Component(dim_in, dim_out),
44  pool_size_(0),
45  pool_step_(0),
46  pool_stride_(0)
47  { }
48 
50  { }
51 
52  Component* Copy() const { return new MaxPoolingComponent(*this); }
54 
55  void InitData(std::istream &is) {
56  // parse config
57  std::string token;
58  while (is >> std::ws, !is.eof()) {
59  ReadToken(is, false, &token);
60  if (token == "<PoolSize>") ReadBasicType(is, false, &pool_size_);
61  else if (token == "<PoolStep>") ReadBasicType(is, false, &pool_step_);
62  else if (token == "<PoolStride>") ReadBasicType(is, false, &pool_stride_);
63  else KALDI_ERR << "Unknown token " << token << ", a typo in config?"
64  << " (PoolSize|PoolStep|PoolStride)";
65  }
66  // check
67  KALDI_ASSERT(pool_size_ != 0 && pool_step_ != 0 && pool_stride_ != 0);
68  }
69 
70  void ReadData(std::istream &is, bool binary) {
71  // pooling hyperparameters
72  ExpectToken(is, binary, "<PoolSize>");
73  ReadBasicType(is, binary, &pool_size_);
74  ExpectToken(is, binary, "<PoolStep>");
75  ReadBasicType(is, binary, &pool_step_);
76  ExpectToken(is, binary, "<PoolStride>");
77  ReadBasicType(is, binary, &pool_stride_);
78 
79  //
80  // Sanity checks:
81  //
82  // number of patches:
84  int32 num_patches = input_dim_ / pool_stride_;
85  // number of pools:
86  KALDI_ASSERT((num_patches - pool_size_) % pool_step_ == 0);
87  int32 num_pools = 1 + (num_patches - pool_size_) / pool_step_;
88  // check output dim:
89  KALDI_ASSERT(output_dim_ == num_pools * pool_stride_);
90  //
91  }
92 
93  void WriteData(std::ostream &os, bool binary) const {
94  // pooling hyperparameters
95  WriteToken(os, binary, "<PoolSize>");
96  WriteBasicType(os, binary, pool_size_);
97  WriteToken(os, binary, "<PoolStep>");
98  WriteBasicType(os, binary, pool_step_);
99  WriteToken(os, binary, "<PoolStride>");
100  WriteBasicType(os, binary, pool_stride_);
101  }
102 
105  // useful dims
106  int32 num_patches = input_dim_ / pool_stride_;
107  int32 num_pools = 1 + (num_patches - pool_size_) / pool_step_;
108 
109  // do the max-pooling (pools indexed by q)
110  for (int32 q = 0; q < num_pools; q++) {
111  // get output buffer of the pool
112  CuSubMatrix<BaseFloat> pool(out->ColRange(q*pool_stride_, pool_stride_));
113  pool.Set(-1e20); // reset (large negative value)
114  for (int32 r = 0; r < pool_size_; r++) { // max
115  int32 p = r + q * pool_step_; // p = input patch
116  pool.Max(in.ColRange(p*pool_stride_, pool_stride_));
117  }
118  }
119  }
120 
122  const CuMatrixBase<BaseFloat> &out,
123  const CuMatrixBase<BaseFloat> &out_diff,
124  CuMatrixBase<BaseFloat> *in_diff) {
125  // useful dims
126  int32 num_patches = input_dim_ / pool_stride_;
127  int32 num_pools = 1 + (num_patches - pool_size_) / pool_step_;
128 
129  //
130  // here we note how many diff matrices are summed for each input patch,
131  std::vector<int32> patch_summands(num_patches, 0);
132  // this metainfo will be used to divide diff of patches
133  // used in more than one pool.
134  //
135 
136  in_diff->SetZero(); // reset
137 
138  for (int32 q = 0; q<num_pools; q++) { // sum
139  for (int32 r = 0; r<pool_size_; r++) {
140  int32 p = r + q * pool_step_; // patch number
141  //
142  CuSubMatrix<BaseFloat> in_p(in.ColRange(p*pool_stride_, pool_stride_));
143  CuSubMatrix<BaseFloat> out_q(out.ColRange(q*pool_stride_, pool_stride_));
144  //
145  CuSubMatrix<BaseFloat> tgt(in_diff->ColRange(p*pool_stride_, pool_stride_));
146  CuMatrix<BaseFloat> src(out_diff.ColRange(q*pool_stride_, pool_stride_));
147 
148  // Only the pool-inputs with 'max-values' are used to back-propagate into,
149  // the rest of derivatives is zeroed-out by a mask.
150  CuMatrix<BaseFloat> mask;
151  in_p.EqualElementMask(out_q, &mask);
152  src.MulElements(mask);
153  tgt.AddMat(1.0, src);
154 
155  patch_summands[p] += 1;
156  }
157  }
158 
159  // divide diff by #summands (compensate for patches used in more pools)
160  for (int32 p = 0; p < num_patches; p++) {
161  CuSubMatrix<BaseFloat> tgt(in_diff->ColRange(p*pool_stride_, pool_stride_));
162  KALDI_ASSERT(patch_summands[p] > 0); // patch at least in one pool
163  tgt.Scale(1.0/patch_summands[p]);
164  }
165  }
166 
167  private:
168  int32 pool_size_, // input patches used for pooling
169  pool_step_, // shift used for pooling (allow overlapping pools)
170  pool_stride_; // stride used to slice input to a vector of matrices
171 };
172 
173 } // namespace nnet1
174 } // namespace kaldi
175 
176 #endif // KALDI_NNET_NNET_MAX_POOLING_COMPONENT_H_
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void WriteData(std::ostream &os, bool binary) const
Writes the component content.
void ReadData(std::istream &is, bool binary)
Reads the component content.
void InitData(std::istream &is)
Virtual interface for initialization and I/O.
MaxPoolingComponent : The input/output matrices are split to submatrices with width 'pool_stride_'...
int32 input_dim_
Data members,.
void BackpropagateFnc(const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrixBase< BaseFloat > *in_diff)
Backward pass transformation (to be implemented by descending class...)
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
Component * Copy() const
Copy component (deep copy),.
ComponentType GetType() const
Get Type Identification of the component,.
kaldi::int32 int32
void ReadToken(std::istream &is, bool binary, std::string *str)
ReadToken gets the next token and puts it in str (exception on failure).
Definition: io-funcs.cc:154
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
Definition: matrix-common.h:71
ComponentType
Component type identification mechanism,.
void PropagateFnc(const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out)
Abstract interface for propagation/backpropagation.
void SetZero()
Math operations, some calling kernels.
Definition: cu-matrix.cc:509
void ExpectToken(std::istream &is, bool binary, const char *token)
ExpectToken tries to read in the given token, and throws an exception on failure. ...
Definition: io-funcs.cc:191
#define KALDI_ERR
Definition: kaldi-error.h:147
This class is used for a piece of a CuMatrix.
Definition: matrix-common.h:70
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
int32 output_dim_
Dimension of the output of the Component,.
CuSubMatrix< Real > ColRange(const MatrixIndexT col_offset, const MatrixIndexT num_cols) const
Definition: cu-matrix.h:665
MaxPoolingComponent(int32 dim_in, int32 dim_out)
Matrix for CUDA computing.
Definition: matrix-common.h:69
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34
Abstract class, building block of the network.
void Set(Real value)
Definition: cu-matrix.cc:531
void EqualElementMask(const CuMatrixBase< Real > &mat, CuMatrix< Real > *mask) const
Definition: cu-matrix.cc:3429