nnet-frame-pooling-component.h
Go to the documentation of this file.
1 // nnet/nnet-frame-pooling-component.h
2 
3 // Copyright 2014 Brno University of Technology (author: Karel Vesely)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
21 #ifndef KALDI_NNET_NNET_FRAME_POOLING_COMPONENT_H_
22 #define KALDI_NNET_NNET_FRAME_POOLING_COMPONENT_H_
23 
24 #include <string>
25 #include <vector>
26 #include <algorithm>
27 #include <sstream>
28 
29 #include "nnet/nnet-component.h"
30 #include "nnet/nnet-utils.h"
31 #include "cudamatrix/cu-math.h"
32 
33 namespace kaldi {
34 namespace nnet1 {
35 
44  public:
45  FramePoolingComponent(int32 dim_in, int32 dim_out):
46  UpdatableComponent(dim_in, dim_out),
47  feature_dim_(0),
48  normalize_(false)
49  { }
50 
52  { }
53 
54  Component* Copy() const { return new FramePoolingComponent(*this); }
56 
61  void InitData(std::istream &is) {
62  // temporary, for initialization,
63  std::vector<int32> pool_size;
64  std::vector<int32> central_offset;
65  Vector<BaseFloat> pool_weight;
66  float learn_rate_coef = 0.01;
67  // parse config
68  std::string token;
69  while (is >> std::ws, !is.eof()) {
70  ReadToken(is, false, &token);
71  if (token == "<FeatureDim>") ReadBasicType(is, false, &feature_dim_);
72  else if (token == "<CentralOffset>") ReadIntegerVector(is, false, &central_offset);
73  else if (token == "<PoolSize>") ReadIntegerVector(is, false, &pool_size);
74  else if (token == "<PoolWeight>") pool_weight.Read(is, false);
75  else if (token == "<LearnRateCoef>") ReadBasicType(is, false, &learn_rate_coef);
76  else if (token == "<Normalize>") ReadBasicType(is, false, &normalize_);
77  else KALDI_ERR << "Unknown token " << token << ", a typo in config?"
78  << " (FeatureDim|CentralOffset <vec>|PoolSize <vec>|LearnRateCoef|Normalize)";
79  }
80  // check inputs:
82  KALDI_ASSERT(central_offset.size() > 0);
83  KALDI_ASSERT(central_offset.size() == pool_size.size());
84  // initialize:
85  int32 num_frames = InputDim() / feature_dim_;
86  int32 central_frame = (num_frames -1) / 2;
87  int32 num_pools = central_offset.size();
88  offset_.resize(num_pools);
89  weight_.resize(num_pools);
90  for (int32 p = 0; p < num_pools; p++) {
91  offset_[p] = central_frame + central_offset[p] + std::min(0, pool_size[p]+1);
92  weight_[p].Resize(std::abs(pool_size[p]));
93  weight_[p].Set(1.0/std::abs(pool_size[p]));
94  }
95  learn_rate_coef_ = learn_rate_coef;
96  if (pool_weight.Dim() != 0) {
97  KALDI_LOG << "Initializing from pool-weight vector";
98  int32 num_weights = 0;
99  for (int32 p = 0; p < num_pools; p++) {
100  weight_[p].CopyFromVec(pool_weight.Range(num_weights, weight_[p].Dim()));
101  num_weights += weight_[p].Dim();
102  }
103  KALDI_ASSERT(num_weights == pool_weight.Dim());
104  }
105  // check that offsets are within the splice we had,
106  for (int32 p = 0; p < num_pools; p++) {
107  KALDI_ASSERT(offset_[p] >= 0);
108  KALDI_ASSERT(offset_[p] + weight_[p].Dim() <= num_frames);
109  }
110  }
111 
116  void ReadData(std::istream &is, bool binary) {
117  // get the input dimension before splicing
118  ExpectToken(is, binary, "<FeatureDim>");
119  ReadBasicType(is, binary, &feature_dim_);
120  ExpectToken(is, binary, "<LearnRateCoef>");
121  ReadBasicType(is, binary, &learn_rate_coef_);
122  ExpectToken(is, binary, "<Normalize>");
123  ReadBasicType(is, binary, &normalize_);
124  // read the offsets w.r.t. central frame
125  ExpectToken(is, binary, "<FrameOffset>");
126  ReadIntegerVector(is, binary, &offset_);
127  // read the frame-weights
128  ExpectToken(is, binary, "<FrameWeight>");
129  int32 num_pools = offset_.size();
130  weight_.resize(num_pools);
131  for (int32 p = 0; p < num_pools; p++) {
132  weight_[p].Read(is, binary);
133  }
134  //
135  // Sanity checks:
136  //
139  KALDI_ASSERT(output_dim_ / feature_dim_ == num_pools);
140  KALDI_ASSERT(offset_.size() == weight_.size());
141  // check the shifts don't exceed the splicing
142  int32 total_frame = InputDim() / feature_dim_;
143  for (int32 p = 0; p < num_pools; p++) {
144  KALDI_ASSERT(offset_[p] >= 0);
145  KALDI_ASSERT(offset_[p] + (weight_[p].Dim()-1) < total_frame);
146  }
147  //
148  }
149 
150  void WriteData(std::ostream &os, bool binary) const {
151  WriteToken(os, binary, "<FeatureDim>");
152  WriteBasicType(os, binary, feature_dim_);
153  WriteToken(os, binary, "<LearnRateCoef>");
154  WriteBasicType(os, binary, learn_rate_coef_);
155  WriteToken(os, binary, "<Normalize>");
156  WriteBasicType(os, binary, normalize_);
157  WriteToken(os, binary, "<FrameOffset>");
158  WriteIntegerVector(os, binary, offset_);
159  // write pooling weights of individual frames
160  WriteToken(os, binary, "<FrameWeight>");
161  int32 num_pools = offset_.size();
162  for (int32 p = 0; p < num_pools; p++) {
163  weight_[p].Write(os, binary);
164  }
165  }
166 
167  int32 NumParams() const {
168  int32 ans = 0;
169  for (int32 p = 0; p < weight_.size(); p++) {
170  ans += weight_[p].Dim();
171  }
172  return ans;
173  }
174 
175  void GetGradient(VectorBase<BaseFloat> *gradient) const {
176  KALDI_ERR << "Unimplemented.";
177  }
178 
179  void GetParams(VectorBase<BaseFloat>* params) const {
180  KALDI_ASSERT(params->Dim() == NumParams());
181  int32 offset = 0;
182  for (int32 p = 0; p < weight_.size(); p++) {
183  params->Range(offset, weight_[p].Dim()).CopyFromVec(weight_[p]);
184  offset += weight_[p].Dim();
185  }
186  KALDI_ASSERT(offset == params->Dim());
187  }
188 
189  void SetParams(const VectorBase<BaseFloat>& params) {
190  KALDI_ERR << "Unimplemented.";
191  }
192 
193  std::string Info() const {
194  std::ostringstream oss;
195  oss << "\n (offset,weights) : ";
196  for (int32 p = 0; p < weight_.size(); p++) {
197  oss << "(" << offset_[p] << "," << weight_[p] << "), ";
198  }
199  return oss.str();
200  }
201 
202  std::string InfoGradient() const {
203  std::ostringstream oss;
204  oss << "\n lr-coef " << ToString(learn_rate_coef_);
205  oss << "\n (offset,weights_grad) : ";
206  for (int32 p = 0; p < weight_diff_.size(); p++) {
207  oss << "(" << offset_[p] << ",";
208  // pass the weight vector, remove '\n' as last char
209  oss << weight_diff_[p];
210  oss.seekp(-1, std::ios_base::cur);
211  oss << "), ";
212  }
213  return oss.str();
214  }
215 
218  // check dims
219  KALDI_ASSERT(in.NumCols() % feature_dim_ == 0);
220  KALDI_ASSERT(out->NumCols() % feature_dim_ == 0);
221  // useful dims
222  int32 num_pools = offset_.size();
223  // compute the output pools
224  for (int32 p = 0; p < num_pools; p++) {
225  CuSubMatrix<BaseFloat> tgt(out->ColRange(p*feature_dim_, feature_dim_));
226  tgt.SetZero(); // reset
227  for (int32 i = 0; i < weight_[p].Dim(); i++) {
228  tgt.AddMat(weight_[p](i), in.ColRange((offset_[p]+i) * feature_dim_, feature_dim_));
229  }
230  }
231  }
232 
234  const CuMatrixBase<BaseFloat> &out,
235  const CuMatrixBase<BaseFloat> &out_diff,
236  CuMatrixBase<BaseFloat> *in_diff) {
237  KALDI_ERR << "Unimplemented.";
238  }
239 
240 
241  void Update(const CuMatrixBase<BaseFloat> &input,
242  const CuMatrixBase<BaseFloat> &diff) {
243  // useful dims
244  int32 num_pools = offset_.size();
245  // lazy init
246  if (weight_diff_.size() != num_pools) weight_diff_.resize(num_pools);
247  // get the derivatives
248  for (int32 p = 0; p < num_pools; p++) {
249  weight_diff_[p].Resize(weight_[p].Dim(), kSetZero); // reset
250  for (int32 i = 0; i < weight_[p].Dim(); i++) {
251  // multiply matrices element-wise, and sum to get the derivative
252  CuSubMatrix<BaseFloat> in_frame(
253  input.ColRange((offset_[p]+i) * feature_dim_, feature_dim_)
254  );
255  CuSubMatrix<BaseFloat> diff_frame(
256  diff.ColRange(p * feature_dim_, feature_dim_)
257  );
258  CuMatrix<BaseFloat> mul_elems(in_frame);
259  mul_elems.MulElements(diff_frame);
260  weight_diff_[p](i) = mul_elems.Sum();
261  }
262  }
263  // update
264  for (int32 p = 0; p < num_pools; p++) {
266  }
267  // force to be positive, re-normalize the sum
268  if (normalize_) {
269  for (int32 p = 0; p < num_pools; p++) {
270  weight_[p].ApplyFloor(0.0);
271  weight_[p].Scale(1.0/weight_[p].Sum());
272  }
273  }
274  }
275 
276  private:
277  int32 feature_dim_; // feature dimension before splicing
278  std::vector<int32> offset_; // vector of pooling offsets
280  std::vector<Vector<BaseFloat> > weight_;
282  std::vector<Vector<BaseFloat> > weight_diff_;
283 
284  bool normalize_; // apply normalization after each update
285 };
286 
287 } // namespace nnet1
288 } // namespace kaldi
289 
290 #endif // KALDI_NNET_NNET_FRAME_POOLING_COMPONENT_H_
std::string ToString(const T &t)
Convert basic type to a string (please don't overuse).
Definition: nnet-utils.h:52
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
ComponentType GetType() const
Get Type Identification of the component,.
NnetTrainOptions opts_
Option-class with training hyper-parameters,.
Component * Copy() const
Copy component (deep copy),.
int32 input_dim_
Data members,.
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
BaseFloat learn_rate_coef_
Scalar applied to learning rate for weight matrices (to be used in ::Update method),.
Class UpdatableComponent is a Component which has trainable parameters, it contains SGD training hype...
Real Sum() const
Definition: cu-matrix.cc:3012
kaldi::int32 int32
void ReadToken(std::istream &is, bool binary, std::string *str)
ReadToken gets the next token and puts it in str (exception on failure).
Definition: io-funcs.cc:154
std::string Info() const
Print some additional info (after <ComponentName> and the dims),.
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
Definition: matrix-common.h:71
ComponentType
Component type identification mechanism,.
void Update(const CuMatrixBase< BaseFloat > &input, const CuMatrixBase< BaseFloat > &diff)
Compute gradient and update parameters,.
std::vector< Vector< BaseFloat > > weight_
Vector of pooling weight vectors,.
void ReadIntegerVector(std::istream &is, bool binary, std::vector< T > *v)
Function for reading STL vector of integer types.
Definition: io-funcs-inl.h:232
void SetZero()
Math operations, some calling kernels.
Definition: cu-matrix.cc:509
void ExpectToken(std::istream &is, bool binary, const char *token)
ExpectToken tries to read in the given token, and throws an exception on failure. ...
Definition: io-funcs.cc:191
void MulElements(const CuMatrixBase< Real > &A)
Multiply two matrices elementwise: C = C .* A.
Definition: cu-matrix.cc:667
void BackpropagateFnc(const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrixBase< BaseFloat > *in_diff)
Backward pass transformation (to be implemented by descending class...)
int32 InputDim() const
Get the dimension of the input,.
FramePoolingComponent(int32 dim_in, int32 dim_out)
#define KALDI_ERR
Definition: kaldi-error.h:147
This class is used for a piece of a CuMatrix.
Definition: matrix-common.h:70
void ReadData(std::istream &is, bool binary)
Here the offsets are w.r.t.
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
void SetParams(const VectorBase< BaseFloat > &params)
Set the trainable parameters from, reshaped as a vector,.
std::string InfoGradient() const
Print some additional info about gradient (after <...> and dims),.
int32 output_dim_
Dimension of the output of the Component,.
CuSubMatrix< Real > ColRange(const MatrixIndexT col_offset, const MatrixIndexT num_cols) const
Definition: cu-matrix.h:665
int32 NumParams() const
Number of trainable parameters,.
Matrix for CUDA computing.
Definition: matrix-common.h:69
MatrixIndexT NumCols() const
Definition: cu-matrix.h:216
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void GetGradient(VectorBase< BaseFloat > *gradient) const
Get gradient reshaped as a vector,.
void InitData(std::istream &is)
Here the offsets are w.r.t.
void WriteIntegerVector(std::ostream &os, bool binary, const std::vector< T > &v)
Function for writing STL vectors of integer types.
Definition: io-funcs-inl.h:198
void PropagateFnc(const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out)
Abstract interface for propagation/backpropagation.
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34
Abstract class, building block of the network.
std::vector< Vector< BaseFloat > > weight_diff_
Derivatives of the weight vectors.
void GetParams(VectorBase< BaseFloat > *params) const
Get the trainable parameters reshaped as a vector,.
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
#define KALDI_LOG
Definition: kaldi-error.h:153
FramePoolingComponent: the input/output matrices are split into frames of width 'feature_dim_'.
void Read(std::istream &in, bool binary, bool add=false)
Read function using C++ streams.
void WriteData(std::ostream &os, bool binary) const
Writes the component content.
SubVector< Real > Range(const MatrixIndexT o, const MatrixIndexT l)
Returns a sub-vector of a vector (a range of elements).
Definition: kaldi-vector.h:94