nnet-average-pooling-component.h
Go to the documentation of this file.
1 // nnet/nnet-average-pooling-component.h
2 
3 // Copyright 2014 Brno University of Technology (author: Karel Vesely)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
21 #ifndef KALDI_NNET_NNET_AVERAGE_POOLING_COMPONENT_H_
22 #define KALDI_NNET_NNET_AVERAGE_POOLING_COMPONENT_H_
23 
24 #include <string>
25 #include <vector>
26 
27 #include "nnet/nnet-component.h"
28 #include "nnet/nnet-utils.h"
29 #include "cudamatrix/cu-math.h"
30 
31 namespace kaldi {
32 namespace nnet1 {
33 
41  public:
43  Component(dim_in, dim_out),
44  pool_size_(0),
45  pool_step_(0),
46  pool_stride_(0)
47  { }
48 
50  { }
51 
52  Component* Copy() const { return new AveragePoolingComponent(*this); }
54 
55  void InitData(std::istream &is) {
56  // parse config
57  std::string token;
58  while (is >> std::ws, !is.eof()) {
59  ReadToken(is, false, &token);
60  if (token == "<PoolSize>") ReadBasicType(is, false, &pool_size_);
61  else if (token == "<PoolStep>") ReadBasicType(is, false, &pool_step_);
62  else if (token == "<PoolStride>") ReadBasicType(is, false, &pool_stride_);
63  else KALDI_ERR << "Unknown token " << token << ", a typo in config?"
64  << " (PoolSize|PoolStep|PoolStride)";
65  }
66  // check
67  KALDI_ASSERT(pool_size_ != 0 && pool_step_ != 0 && pool_stride_ != 0);
68  }
69 
70  void ReadData(std::istream &is, bool binary) {
71  // pooling hyperparameters
72  ExpectToken(is, binary, "<PoolSize>");
73  ReadBasicType(is, binary, &pool_size_);
74  ExpectToken(is, binary, "<PoolStep>");
75  ReadBasicType(is, binary, &pool_step_);
76  ExpectToken(is, binary, "<PoolStride>");
77  ReadBasicType(is, binary, &pool_stride_);
78 
79  //
80  // Sanity checks:
81  //
82  // number of patches:
84  int32 num_patches = input_dim_ / pool_stride_;
85  // number of pools:
86  KALDI_ASSERT((num_patches - pool_size_) % pool_step_ == 0);
87  int32 num_pools = 1 + (num_patches - pool_size_) / pool_step_;
88  // check output dim:
89  KALDI_ASSERT(output_dim_ == num_pools * pool_stride_);
90  //
91  }
92 
93  void WriteData(std::ostream &os, bool binary) const {
94  // pooling hyperparameters
95  WriteToken(os, binary, "<PoolSize>");
96  WriteBasicType(os, binary, pool_size_);
97  WriteToken(os, binary, "<PoolStep>");
98  WriteBasicType(os, binary, pool_step_);
99  WriteToken(os, binary, "<PoolStride>");
100  WriteBasicType(os, binary, pool_stride_);
101  }
102 
105  // useful dims
106  int32 num_patches = input_dim_ / pool_stride_;
107  int32 num_pools = 1 + (num_patches - pool_size_) / pool_step_;
108 
109  // do the average-pooling (pools indexed by q)
110  for (int32 q = 0; q < num_pools; q++) {
111  // get output buffer of the pool
112  CuSubMatrix<BaseFloat> pool(out->ColRange(q*pool_stride_, pool_stride_));
113  pool.SetZero(); // reset,
114  for (int32 r = 0; r < pool_size_; r++) { // sum
115  int32 p = r + q * pool_step_; // p = input patch
116  pool.AddMat(1.0, in.ColRange(p*pool_stride_, pool_stride_));
117  }
118  pool.Scale(1.0 / pool_size_); // divide by #summands
119  }
120  }
121 
123  const CuMatrixBase<BaseFloat> &out,
124  const CuMatrixBase<BaseFloat> &out_diff,
125  CuMatrixBase<BaseFloat> *in_diff) {
126  // useful dims
127  int32 num_patches = input_dim_ / pool_stride_;
128  int32 num_pools = 1 + (num_patches - pool_size_) / pool_step_;
129 
130  //
131  // here we note how many diff matrices are summed for each input patch,
132  std::vector<int32> patch_summands(num_patches, 0);
133  // this metainfo will be used to divide diff of patches
134  // used in more than one pool.
135  //
136 
137  in_diff->SetZero(); // reset
138 
139  for (int32 q = 0; q < num_pools; q++) { // sum
140  for (int32 r = 0; r < pool_size_; r++) {
141  int32 p = r + q * pool_step_;
142  CuSubMatrix<BaseFloat> tgt(in_diff->ColRange(p*pool_stride_, pool_stride_));
143  CuSubMatrix<BaseFloat> src(out_diff.ColRange(q*pool_stride_, pool_stride_));
144  tgt.AddMat(1.0, src);
145  patch_summands[p] += 1;
146  }
147  }
148 
149  // divide diff by average-pooling-dim (derivative of averaging)
150  in_diff->Scale(1.0 / pool_size_);
151 
152  // divide diff by #summands (compensate for patches used in more pools)
153  for (int32 p = 0; p < num_patches; p++) {
154  CuSubMatrix<BaseFloat> tgt(in_diff->ColRange(p*pool_stride_, pool_stride_));
155  KALDI_ASSERT(patch_summands[p] > 0); // patch at least in one pool
156  tgt.Scale(1.0/patch_summands[p]);
157  }
158  }
159 
160  private:
161  int32 pool_size_, // input patches used for pooling
162  pool_step_, // shift used for pooling (allow overlapping pools)
163  pool_stride_; // stride used to cut input to a vector of matrices
164 };
165 
166 } // namespace nnet1
167 } // namespace kaldi
168 
169 #endif // KALDI_NNET_NNET_AVERAGE_POOLING_COMPONENT_H_
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
int32 input_dim_
Data members,.
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
kaldi::int32 int32
void ReadToken(std::istream &is, bool binary, std::string *str)
ReadToken gets the next token and puts it in str (exception on failure).
Definition: io-funcs.cc:154
void AddMat(Real alpha, const CuMatrixBase< Real > &A, MatrixTransposeType trans=kNoTrans)
*this += alpha * A
Definition: cu-matrix.cc:954
ComponentType
Component type identification mechanism,.
void BackpropagateFnc(const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrixBase< BaseFloat > *in_diff)
Backward pass transformation (to be implemented by descending class...)
void WriteData(std::ostream &os, bool binary) const
Writes the component content.
void PropagateFnc(const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out)
Abstract interface for propagation/backpropagation.
void Scale(Real value)
Definition: cu-matrix.cc:644
void SetZero()
Math operations, some calling kernels.
Definition: cu-matrix.cc:509
void ExpectToken(std::istream &is, bool binary, const char *token)
ExpectToken tries to read in the given token, and throws an exception on failure. ...
Definition: io-funcs.cc:191
void InitData(std::istream &is)
Virtual interface for initialization and I/O,.
#define KALDI_ERR
Definition: kaldi-error.h:147
Component * Copy() const
Copy component (deep copy),.
void ReadData(std::istream &is, bool binary)
Reads the component content.
This class is used for a piece of a CuMatrix.
Definition: matrix-common.h:70
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
int32 output_dim_
Dimension of the output of the Component,.
CuSubMatrix< Real > ColRange(const MatrixIndexT col_offset, const MatrixIndexT num_cols) const
Definition: cu-matrix.h:665
Matrix for CUDA computing.
Definition: matrix-common.h:69
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
ComponentType GetType() const
Get Type Identification of the component,.
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34
Abstract class, building block of the network.
AveragePoolingComponent : The input/output matrices are split to submatrices with width 'pool_stride_...