nnet-attention-component.h
Go to the documentation of this file.
1 // nnet3/nnet-attention-component.h
2 
3 // Copyright 2017 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #ifndef KALDI_NNET3_NNET_ATTENTION_COMPONENT_H_
21 #define KALDI_NNET3_NNET_ATTENTION_COMPONENT_H_
22 
23 #include "nnet3/nnet-common.h"
26 #include "nnet3/attention.h"
27 #include <iostream>
28 
29 namespace kaldi {
30 namespace nnet3 {
31 
35 
36 
37 
107  public:
108 
109  // The use of this constructor should only precede InitFromConfig()
111 
112  // Copy constructor
114 
115  virtual int32 InputDim() const {
116  // the input is interpreted as being appended blocks one for each head; each
117  // such block is interpreted as (key, value, query).
118  int32 query_dim = key_dim_ + context_dim_;
119  return num_heads_ * (key_dim_ + value_dim_ + query_dim);
120  }
121  virtual int32 OutputDim() const {
122  // the output consists of appended blocks, one for each head; each such
123  // block is is the attention weighted average of the input values, to which
124  // we append softmax encoding of the positions we chose, if output_context_
125  // == true.
126  return num_heads_ * (value_dim_ + (output_context_ ? context_dim_ : 0));
127  }
128  virtual std::string Info() const;
129  virtual void InitFromConfig(ConfigLine *cfl);
130  virtual std::string Type() const { return "RestrictedAttentionComponent"; }
131  virtual int32 Properties() const {
134  }
135  virtual void* Propagate(const ComponentPrecomputedIndexes *indexes,
136  const CuMatrixBase<BaseFloat> &in,
137  CuMatrixBase<BaseFloat> *out) const;
138  virtual void StoreStats(const CuMatrixBase<BaseFloat> &in_value,
139  const CuMatrixBase<BaseFloat> &out_value,
140  void *memo);
141  virtual void Scale(BaseFloat scale);
142  virtual void Add(BaseFloat alpha, const Component &other);
143  virtual void ZeroStats();
144 
145  virtual void Backprop(const std::string &debug_info,
146  const ComponentPrecomputedIndexes *indexes,
147  const CuMatrixBase<BaseFloat> &in_value,
148  const CuMatrixBase<BaseFloat> &out_value,
149  const CuMatrixBase<BaseFloat> &out_deriv,
150  void *memo,
151  Component *to_update,
152  CuMatrixBase<BaseFloat> *in_deriv) const;
153  virtual void Read(std::istream &is, bool binary);
154  virtual void Write(std::ostream &os, bool binary) const;
155  virtual Component* Copy() const {
156  return new RestrictedAttentionComponent(*this);
157  }
158  virtual void DeleteMemo(void *memo) const { delete static_cast<Memo*>(memo); }
159 
160  // Some functions that are only to be reimplemented for GeneralComponents.
161 
162  // This ReorderIndexes function may insert 'blank' indexes (indexes with
163  // t == kNoTime) as well as reordering the indexes. This is allowed
164  // behavior of ReorderIndexes functions.
165  virtual void ReorderIndexes(std::vector<Index> *input_indexes,
166  std::vector<Index> *output_indexes) const;
167 
168  virtual void GetInputIndexes(const MiscComputationInfo &misc_info,
169  const Index &output_index,
170  std::vector<Index> *desired_indexes) const;
171 
172  // This function returns true if at least one of the input indexes used to
173  // compute this output index is computable.
174  virtual bool IsComputable(const MiscComputationInfo &misc_info,
175  const Index &output_index,
176  const IndexSet &input_index_set,
177  std::vector<Index> *used_inputs) const;
178 
180  const MiscComputationInfo &misc_info,
181  const std::vector<Index> &input_indexes,
182  const std::vector<Index> &output_indexes,
183  bool need_backprop) const;
184 
186  public:
189  io(other.io) { }
190  virtual PrecomputedIndexes *Copy() const;
191  virtual void Write(std::ostream &os, bool binary) const;
192  virtual void Read(std::istream &os, bool binary);
193  virtual std::string Type() const {
194  return "RestrictedAttentionComponentPrecomputedIndexes";
195  }
196  virtual ~PrecomputedIndexes() { }
197 
199  };
200 
201  // This is what's returned as the 'memo' from the Propagate() function.
202  struct Memo {
203  // c is of dimension (num_heads_ * num-output-frames) by context_dim_,
204  // where num-output-frames is the number of frames of output the
205  // corresponding Propagate function produces.
206  // Each block of 'num-output-frames' rows of c_t is the
207  // post-softmax matrix of weights.
209  };
210 
211  private:
212 
213  // Does the propagation for one head; this is called for each
214  // head by the top-level Propagate function. Later on we may
215  // figure out a way to avoid doing this sequentially.
216  // 'in' and 'out' are submatrices of the 'in' and 'out' passed
217  // to the top-level Propagate function, and 'c' is a submatrix
218  // of the 'c' matrix in the memo we're creating.
219  //
220  // Assumes 'c' has already been zeroed.
221  void PropagateOneHead(
223  const CuMatrixBase<BaseFloat> &in,
225  CuMatrixBase<BaseFloat> *out) const;
226 
227 
228  // does the backprop for one head; called by Backprop().
229  void BackpropOneHead(
231  const CuMatrixBase<BaseFloat> &in_value,
232  const CuMatrixBase<BaseFloat> &c,
233  const CuMatrixBase<BaseFloat> &out_deriv,
234  CuMatrixBase<BaseFloat> *in_deriv) const;
235 
236  // This function, used in ReorderIndexes() and PrecomputedIndexes(),
237  // works out what grid structure over time we will have at the input
238  // and the output.
239  // Note: it may produce a grid that encompasses more than what was
240  // listed in 'input_indexes' and 'output_indexes'. This is OK.
241  // ReorderIndexes() will add placeholders with t == kNoTime for
242  // padding, and at the input of this component those placeholders
243  // will be zero; at the output the placeholders will be ignored.
245  const std::vector<Index> &input_indexes,
246  const std::vector<Index> &output_indexes,
248 
249  // This function, used in ReorderIndexes(), obtains the indexes with the
250  // correct structure and order (the structure is specified in the 'io' object).
251  // This may involve not just reordering the provided indexes, but padding them
252  // with indexes that have kNoTime as the time.
253  //
254  // Basically the indexes this function outputs form a grid where 't' has the
255  // larger stride than the (n, x) pairs. The number of distinct (n, x) pairs
256  // should equal io.num_images. Where 't' values need to appear in the
257  // new indexes that were not present in the old indexes, they get replaced with
258  // kNoTime.
259  void GetIndexes(
260  const std::vector<Index> &input_indexes,
261  const std::vector<Index> &output_indexes,
263  std::vector<Index> *new_input_indexes,
264  std::vector<Index> *new_output_indexes) const;
265 
272  static void CreateIndexesVector(
273  const std::vector<std::pair<int32, int32> > &n_x_pairs,
274  int32 t_start, int32 t_step, int32 num_t_values,
275  const std::unordered_set<Index, IndexHasher> &index_set,
276  std::vector<Index> *output_indexes);
277 
278 
279  void Check() const;
280 
287  int32 context_dim_; // This derived parameter equals 1 + num_left_inputs_ +
288  // num_right_inputs_.
293 
294  double stats_count_; // Count of frames corresponding to the stats.
295  Vector<double> entropy_stats_; // entropy stats, indexed per head.
296  // (dimension is num_heads_). Divide
297  // by stats_count_ to normalize.
298  CuMatrix<double> posterior_stats_; // stats of posteriors of different
299  // offsets, of dimension num_heads_ *
300  // context_dim_ (num-heads has the
301  // larger stride).
302 };
303 
304 
305 
306 
307 } // namespace nnet3
308 } // namespace kaldi
309 
310 
311 #endif
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void GetIndexes(const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, time_height_convolution::ConvolutionComputationIo &io, std::vector< Index > *new_input_indexes, std::vector< Index > *new_output_indexes) const
virtual void Add(BaseFloat alpha, const Component &other)
This virtual function when called by – an UpdatableComponent adds the parameters of another updatabl...
void GetComputationStructure(const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, time_height_convolution::ConvolutionComputationIo *io) const
virtual int32 InputDim() const
Returns input-dimension of this component.
virtual void Read(std::istream &is, bool binary)
Read function (used after we know the type of the Component); accepts input that is missing the token...
Abstract base-class for neural-net components.
virtual Component * Copy() const
Copies component (deep copy).
An abstract representation of a set of Indexes.
void BackpropOneHead(const time_height_convolution::ConvolutionComputationIo &io, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &c, const CuMatrixBase< BaseFloat > &out_deriv, CuMatrixBase< BaseFloat > *in_deriv) const
kaldi::int32 int32
static void CreateIndexesVector(const std::vector< std::pair< int32, int32 > > &n_x_pairs, int32 t_start, int32 t_step, int32 num_t_values, const std::unordered_set< Index, IndexHasher > &index_set, std::vector< Index > *output_indexes)
Utility function used in GetIndexes().
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
Definition: matrix-common.h:71
virtual bool IsComputable(const MiscComputationInfo &misc_info, const Index &output_index, const IndexSet &input_index_set, std::vector< Index > *used_inputs) const
This function only does something interesting for non-simple Components, and it exists to make it pos...
virtual int32 OutputDim() const
Returns output-dimension of this component.
virtual void Backprop(const std::string &debug_info, const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, const CuMatrixBase< BaseFloat > &out_deriv, void *memo, Component *to_update, CuMatrixBase< BaseFloat > *in_deriv) const
Backprop function; depending on which of the arguments 'to_update' and 'in_deriv' are non-NULL...
time_height_convolution::ConvolutionComputationIo io
virtual void DeleteMemo(void *memo) const
This virtual function only needs to be overwritten by Components that return a non-NULL memo from the...
struct Index is intended to represent the various indexes by which we number the rows of the matrices...
Definition: nnet-common.h:44
virtual void Scale(BaseFloat scale)
This virtual function when called on – an UpdatableComponent scales the parameters by "scale" when c...
virtual void InitFromConfig(ConfigLine *cfl)
Initialize, from a ConfigLine object.
virtual ComponentPrecomputedIndexes * PrecomputeIndexes(const MiscComputationInfo &misc_info, const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, bool need_backprop) const
This function must return NULL for simple Components.
virtual std::string Type() const
Returns a string such as "SigmoidComponent", describing the type of the object.
virtual void ZeroStats()
Components that provide an implementation of StoreStats should also provide an implementation of Zero...
virtual void GetInputIndexes(const MiscComputationInfo &misc_info, const Index &output_index, std::vector< Index > *desired_indexes) const
This function only does something interesting for non-simple Components.
virtual std::string Info() const
Returns some text-form information about this component, for diagnostics.
virtual void * Propagate(const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) const
Propagate function.
virtual void Write(std::ostream &os, bool binary) const
Write component to stream.
virtual void ReorderIndexes(std::vector< Index > *input_indexes, std::vector< Index > *output_indexes) const
This function only does something interesting for non-simple Components.
Matrix for CUDA computing.
Definition: matrix-common.h:69
This class is responsible for parsing input like hi-there xx=yyy a=b c empty= f-oo=Append(bar, sss) ba_z=123 bing='a b c' baz="a b c d='a b' e" and giving you access to the fields, in this case.
Definition: text-utils.h:205
RestrictedAttentionComponent implements an attention model with restricted temporal context...
virtual void Write(std::ostream &os, bool binary) const
This file contains the lower-level interface for self-attention.
virtual int32 Properties() const
Return bitmask of the component's properties.
void PropagateOneHead(const time_height_convolution::ConvolutionComputationIo &io, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *c, CuMatrixBase< BaseFloat > *out) const
virtual void StoreStats(const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, void *memo)
This function may store stats on average activation values, and for some component types...