// nnet3/nnet-attention-component.cc

// Copyright      2017  Johns Hopkins University (author: Daniel Povey)
//                2017  Hossein Hadian

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include <iterator>
#include <sstream>
#include <iomanip>
#include "nnet3/nnet-attention-component.h"
#include "nnet3/nnet-parse.h"
#include "nnet3/nnet-computation-graph.h"

namespace kaldi {
namespace nnet3 {


std::string RestrictedAttentionComponent::Info() const {
  std::stringstream stream;
  stream << Type() << ", input-dim=" << InputDim()
         << ", output-dim=" << OutputDim()
         << ", num-heads=" << num_heads_
         << ", time-stride=" << time_stride_
         << ", key-dim=" << key_dim_
         << ", value-dim=" << value_dim_
         << ", num-left-inputs=" << num_left_inputs_
         << ", num-right-inputs=" << num_right_inputs_
         << ", context-dim=" << context_dim_
         << ", num-left-inputs-required=" << num_left_inputs_required_
         << ", num-right-inputs-required=" << num_right_inputs_required_
         << ", output-context=" << (output_context_ ? "true" : "false")
         << ", key-scale=" << key_scale_;
  if (stats_count_ != 0.0) {
    stream << ", entropy=";
    for (int32 i = 0; i < entropy_stats_.Dim(); i++)
      stream << (entropy_stats_(i) / stats_count_) << ',';
    for (int32 i = 0; i < num_heads_ && i < 5; i++) {
      stream << " posterior-stats[" << i << "]=";
      for (int32 j = 0; j < posterior_stats_.NumCols(); j++)
        stream << (posterior_stats_(i, j) / stats_count_) << ',';
    }
    stream << " stats-count=" << stats_count_;
  }
  return stream.str();
}

RestrictedAttentionComponent::RestrictedAttentionComponent(
    const RestrictedAttentionComponent &other):
    num_heads_(other.num_heads_),
    key_dim_(other.key_dim_),
    value_dim_(other.value_dim_),
    num_left_inputs_(other.num_left_inputs_),
    num_right_inputs_(other.num_right_inputs_),
    time_stride_(other.time_stride_),
    context_dim_(other.context_dim_),
    num_left_inputs_required_(other.num_left_inputs_required_),
    num_right_inputs_required_(other.num_right_inputs_required_),
    output_context_(other.output_context_),
    key_scale_(other.key_scale_),
    stats_count_(other.stats_count_),
    entropy_stats_(other.entropy_stats_),
    posterior_stats_(other.posterior_stats_) { }


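// Illustrative example of a config line that InitFromConfig() accepts
// (key-dim, value-dim, num-left-inputs and num-right-inputs are mandatory;
// the remaining values shown here are optional and the numbers are only
// examples):
//   key-dim=40 value-dim=60 num-left-inputs=5 num-right-inputs=2 num-heads=1 time-stride=3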
void RestrictedAttentionComponent::InitFromConfig(ConfigLine *cfl) {
  num_heads_ = 1;
  key_dim_ = -1;
  value_dim_ = -1;
  num_left_inputs_ = -1;
  num_right_inputs_ = -1;
  time_stride_ = 1;
  num_left_inputs_required_ = -1;
  num_right_inputs_required_ = -1;
  output_context_ = true;
  key_scale_ = -1.0;


  // mandatory arguments.
  bool ok = cfl->GetValue("key-dim", &key_dim_) &&
      cfl->GetValue("value-dim", &value_dim_) &&
      cfl->GetValue("num-left-inputs", &num_left_inputs_) &&
      cfl->GetValue("num-right-inputs", &num_right_inputs_);

  if (!ok)
    KALDI_ERR << "All of the values key-dim, value-dim, "
        "num-left-inputs and num-right-inputs must be defined.";
  // optional arguments.
  cfl->GetValue("num-heads", &num_heads_);
  cfl->GetValue("time-stride", &time_stride_);
  cfl->GetValue("num-left-inputs-required", &num_left_inputs_required_);
  cfl->GetValue("num-right-inputs-required", &num_right_inputs_required_);
  cfl->GetValue("output-context", &output_context_);
  cfl->GetValue("key-scale", &key_scale_);

  if (key_scale_ < 0.0) key_scale_ = 1.0 / sqrt(key_dim_);
  if (num_left_inputs_required_ < 0)
    num_left_inputs_required_ = num_left_inputs_;
  if (num_right_inputs_required_ < 0)
    num_right_inputs_required_ = num_right_inputs_;

  if (num_heads_ <= 0 || key_dim_ <= 0 || value_dim_ <= 0 ||
      num_left_inputs_ < 0 || num_right_inputs_ < 0 ||
      (num_left_inputs_ + num_right_inputs_) <= 0 ||
      num_left_inputs_required_ > num_left_inputs_ ||
      num_right_inputs_required_ > num_right_inputs_ ||
      time_stride_ <= 0)
    KALDI_ERR << "Config line contains invalid values: "
              << cfl->WholeLine();
  stats_count_ = 0.0;
  context_dim_ = num_left_inputs_ + 1 + num_right_inputs_;
  Check();
}


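// Note on the per-head column layout assumed by Propagate() and Backprop():
// for each head the input columns are ordered
//   [ key (key_dim_) | value (value_dim_) | query (key_dim_ + context_dim_) ],
// and the output columns are the attended value (value_dim_), optionally
// followed by the attention weights (context_dim_) when output_context_ is
// true.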
void* RestrictedAttentionComponent::Propagate(
    const ComponentPrecomputedIndexes *indexes_in,
    const CuMatrixBase<BaseFloat> &in,
    CuMatrixBase<BaseFloat> *out) const {
  const PrecomputedIndexes *indexes = dynamic_cast<const PrecomputedIndexes*>(
      indexes_in);
  KALDI_ASSERT(indexes != NULL &&
               in.NumRows() == indexes->io.num_t_in * indexes->io.num_images &&
               out->NumRows() == indexes->io.num_t_out * indexes->io.num_images);


  Memo *memo = new Memo();
  memo->c.Resize(out->NumRows(), context_dim_ * num_heads_);

  int32 query_dim = key_dim_ + context_dim_;
  int32 input_dim_per_head = key_dim_ + value_dim_ + query_dim,
      output_dim_per_head = value_dim_ + (output_context_ ? context_dim_ : 0);
  for (int32 h = 0; h < num_heads_; h++) {
    CuSubMatrix<BaseFloat> in_part(in, 0, in.NumRows(),
                                   h * input_dim_per_head, input_dim_per_head),
        c_part(memo->c, 0, out->NumRows(),
               h * context_dim_, context_dim_),
        out_part(*out, 0, out->NumRows(),
                 h * output_dim_per_head, output_dim_per_head);
    PropagateOneHead(indexes->io, in_part, &c_part, &out_part);
  }
  return static_cast<void*>(memo);
}

void RestrictedAttentionComponent::PropagateOneHead(
    const time_height_convolution::ConvolutionComputationIo &io,
    const CuMatrixBase<BaseFloat> &in,
    CuMatrixBase<BaseFloat> *c,
    CuMatrixBase<BaseFloat> *out) const {
  int32 query_dim = key_dim_ + context_dim_,
      full_value_dim = value_dim_ + (output_context_ ? context_dim_ : 0);
  KALDI_ASSERT(in.NumRows() == io.num_images * io.num_t_in &&
               out->NumRows() == io.num_images * io.num_t_out &&
               out->NumCols() == full_value_dim &&
               in.NumCols() == (key_dim_ + value_dim_ + query_dim) &&
               io.t_step_in == io.t_step_out &&
               (io.start_t_out - io.start_t_in) % io.t_step_in == 0);

  // 'steps_left_context' is the number of time-steps the input has on the left
  // that don't appear in the output.
  int32 steps_left_context = (io.start_t_out - io.start_t_in) / io.t_step_in,
      rows_left_context = steps_left_context * io.num_images;
  KALDI_ASSERT(rows_left_context >= 0);

  // 'queries' contains the queries.  We don't use all rows of the input here,
  // only the rows that correspond to the time-indexes at the output; the
  // remaining rows of 'in' correspond to left and right temporal context.
  CuSubMatrix<BaseFloat> queries(in, rows_left_context, out->NumRows(),
                                 key_dim_ + value_dim_, query_dim);
  // 'keys' contains the keys; note, these are not extended with
  // context information; that happens further in.
  CuSubMatrix<BaseFloat> keys(in, 0, in.NumRows(), 0, key_dim_);

  // 'values' contains the values which we will interpolate.
  // These don't contain the context information; that will be added
  // later if output_context_ == true.
  CuSubMatrix<BaseFloat> values(in, 0, in.NumRows(), key_dim_, value_dim_);

  attention::AttentionForward(key_scale_, keys, queries, values, c, out);
}


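// StoreStats() accumulates diagnostic statistics from the attention weights
// 'c' kept in the memo: the summed posterior over each context position per
// head, and the per-head entropy of the per-frame attention distribution.
// Both are divided by stats_count_ when they are printed by Info().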
void RestrictedAttentionComponent::StoreStats(
    const CuMatrixBase<BaseFloat> &,  // in_value
    const CuMatrixBase<BaseFloat> &,  // out_value
    void *memo_in) {
  const Memo *memo = static_cast<const Memo*>(memo_in);
  KALDI_ASSERT(memo != NULL);
  if (entropy_stats_.Dim() != num_heads_) {
    entropy_stats_.Resize(num_heads_);
    posterior_stats_.Resize(num_heads_, context_dim_);
    stats_count_ = 0.0;
  }
  const CuMatrix<BaseFloat> &c = memo->c;
  if (RandInt(0, 2) == 0)
    return;  // only actually store the stats for one in three minibatches, to
             // save time.

  {  // first get the posterior stats.
    CuVector<BaseFloat> c_sum(c.NumCols());
    c_sum.AddRowSumMat(1.0, c, 0.0);
    // view the vector as a matrix.
    CuSubMatrix<BaseFloat> c_sum_as_mat(c_sum.Data(), num_heads_,
                                        context_dim_, context_dim_);
    CuMatrix<double> c_sum_as_mat_dbl(c_sum_as_mat);
    posterior_stats_.AddMat(1.0, c_sum_as_mat_dbl);
  }
  {  // now get the entropy stats.
    CuMatrix<BaseFloat> log_c(c);
    log_c.ApplyFloor(1.0e-20);
    log_c.ApplyLog();
    CuVector<BaseFloat> dot_prod(c.NumCols());
    dot_prod.AddDiagMatMat(-1.0, c, kTrans, log_c, kNoTrans, 0.0);
    // dot_prod is the sum over the matrix's rows (which correspond
    // to heads, and context positions), of - c * log(c), which is
    // part of the entropy.  To get the actual contribution to the
    // entropy, we have to sum 'dot_prod' over blocks of
    // size 'context_dim_'; that gives us the entropy contribution
    // per head.  We'd have to divide by c.NumRows() to get the
    // actual entropy, but that's reflected in stats_count_.
    CuSubMatrix<BaseFloat> entropy_mat(dot_prod.Data(), num_heads_,
                                       context_dim_, context_dim_);
    CuVector<BaseFloat> entropy_vec(num_heads_);
    entropy_vec.AddColSumMat(1.0, entropy_mat);
    Vector<double> entropy_vec_dbl(entropy_vec);
    entropy_stats_.AddVec(1.0, entropy_vec_dbl);
  }
  stats_count_ += c.NumRows();
}

void RestrictedAttentionComponent::ZeroStats() {
  entropy_stats_.SetZero();
  posterior_stats_.SetZero();
  stats_count_ = 0.0;
}

void RestrictedAttentionComponent::Scale(BaseFloat scale) {
  entropy_stats_.Scale(scale);
  posterior_stats_.Scale(scale);
  stats_count_ *= scale;
}

void RestrictedAttentionComponent::Add(BaseFloat alpha, const Component &other_in) {
  const RestrictedAttentionComponent *other =
      dynamic_cast<const RestrictedAttentionComponent*>(&other_in);
  KALDI_ASSERT(other != NULL);
  if (entropy_stats_.Dim() == 0 && other->entropy_stats_.Dim() != 0)
    entropy_stats_.Resize(other->entropy_stats_.Dim());
  if (posterior_stats_.NumRows() == 0 && other->posterior_stats_.NumRows() != 0)
    posterior_stats_.Resize(other->posterior_stats_.NumRows(),
                            other->posterior_stats_.NumCols());
  if (other->entropy_stats_.Dim() != 0)
    entropy_stats_.AddVec(alpha, other->entropy_stats_);
  if (other->posterior_stats_.NumRows() != 0)
    posterior_stats_.AddMat(alpha, other->posterior_stats_);
  stats_count_ += alpha * other->stats_count_;
}


void RestrictedAttentionComponent::Check() const {
  KALDI_ASSERT(num_heads_ > 0 && key_dim_ > 0 && value_dim_ > 0 &&
               num_left_inputs_ >= 0 && num_right_inputs_ >= 0 &&
               (num_left_inputs_ + num_right_inputs_) > 0 &&
               time_stride_ > 0 &&
               context_dim_ == (num_left_inputs_ + 1 + num_right_inputs_) &&
               num_left_inputs_required_ >= 0 &&
               num_left_inputs_required_ <= num_left_inputs_ &&
               num_right_inputs_required_ >= 0 &&
               num_right_inputs_required_ <= num_right_inputs_ &&
               key_scale_ > 0.0 && key_scale_ <= 1.0 &&
               stats_count_ >= 0.0);
}


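// Backprop() mirrors Propagate(): the derivative is propagated back one head
// at a time over the same column ranges.  This component has no learnable
// parameters, so 'to_update_in' is not used here.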
void RestrictedAttentionComponent::Backprop(
    const std::string &debug_info,
    const ComponentPrecomputedIndexes *indexes_in,
    const CuMatrixBase<BaseFloat> &in_value,
    const CuMatrixBase<BaseFloat> &,  // out_value
    const CuMatrixBase<BaseFloat> &out_deriv,
    void *memo_in,
    Component *to_update_in,
    CuMatrixBase<BaseFloat> *in_deriv) const {
  NVTX_RANGE("RestrictedAttentionComponent::Backprop");
  const PrecomputedIndexes *indexes =
      dynamic_cast<const PrecomputedIndexes*>(indexes_in);
  KALDI_ASSERT(indexes != NULL);
  Memo *memo = static_cast<Memo*>(memo_in);
  KALDI_ASSERT(memo != NULL);
  const time_height_convolution::ConvolutionComputationIo &io = indexes->io;
  KALDI_ASSERT(indexes != NULL &&
               in_value.NumRows() == io.num_t_in * io.num_images &&
               out_deriv.NumRows() == io.num_t_out * io.num_images &&
               in_deriv != NULL && SameDim(in_value, *in_deriv));

  const CuMatrix<BaseFloat> &c = memo->c;

  int32 query_dim = key_dim_ + context_dim_,
      input_dim_per_head = key_dim_ + value_dim_ + query_dim,
      output_dim_per_head = value_dim_ + (output_context_ ? context_dim_ : 0);

  for (int32 h = 0; h < num_heads_; h++) {
    CuSubMatrix<BaseFloat>
        in_value_part(in_value, 0, in_value.NumRows(),
                      h * input_dim_per_head, input_dim_per_head),
        c_part(c, 0, out_deriv.NumRows(),
               h * context_dim_, context_dim_),
        out_deriv_part(out_deriv, 0, out_deriv.NumRows(),
                       h * output_dim_per_head, output_dim_per_head),
        in_deriv_part(*in_deriv, 0, in_value.NumRows(),
                      h * input_dim_per_head, input_dim_per_head);
    BackpropOneHead(io, in_value_part, c_part, out_deriv_part,
                    &in_deriv_part);
  }
}


void RestrictedAttentionComponent::BackpropOneHead(
    const time_height_convolution::ConvolutionComputationIo &io,
    const CuMatrixBase<BaseFloat> &in_value,
    const CuMatrixBase<BaseFloat> &c,
    const CuMatrixBase<BaseFloat> &out_deriv,
    CuMatrixBase<BaseFloat> *in_deriv) const {
  // the easiest way to understand this is to compare it with PropagateOneHead().
  int32 query_dim = key_dim_ + context_dim_,
      full_value_dim = value_dim_ + (output_context_ ? context_dim_ : 0);
  KALDI_ASSERT(in_value.NumRows() == io.num_images * io.num_t_in &&
               out_deriv.NumRows() == io.num_images * io.num_t_out &&
               out_deriv.NumCols() == full_value_dim &&
               in_value.NumCols() == (key_dim_ + value_dim_ + query_dim) &&
               io.t_step_in == io.t_step_out &&
               (io.start_t_out - io.start_t_in) % io.t_step_in == 0 &&
               SameDim(in_value, *in_deriv) &&
               c.NumRows() == out_deriv.NumRows() &&
               c.NumCols() == context_dim_);

  // 'steps_left_context' is the number of time-steps the input has on the left
  // that don't appear in the output.
  int32 steps_left_context = (io.start_t_out - io.start_t_in) / io.t_step_in,
      rows_left_context = steps_left_context * io.num_images;
  KALDI_ASSERT(rows_left_context >= 0);


  CuSubMatrix<BaseFloat> queries(in_value, rows_left_context, out_deriv.NumRows(),
                                 key_dim_ + value_dim_, query_dim),
      queries_deriv(*in_deriv, rows_left_context, out_deriv.NumRows(),
                    key_dim_ + value_dim_, query_dim),
      keys(in_value, 0, in_value.NumRows(), 0, key_dim_),
      keys_deriv(*in_deriv, 0, in_value.NumRows(), 0, key_dim_),
      values(in_value, 0, in_value.NumRows(), key_dim_, value_dim_),
      values_deriv(*in_deriv, 0, in_value.NumRows(), key_dim_, value_dim_);

  attention::AttentionBackward(key_scale_, keys, queries, values, c, out_deriv,
                               &keys_deriv, &queries_deriv, &values_deriv);
}



void RestrictedAttentionComponent::ReorderIndexes(
    std::vector<Index> *input_indexes,
    std::vector<Index> *output_indexes) const {
  using namespace time_height_convolution;
  ConvolutionComputationIo io;
  GetComputationStructure(*input_indexes, *output_indexes, &io);
  std::vector<Index> new_input_indexes, new_output_indexes;
  GetIndexes(*input_indexes, *output_indexes, io,
             &new_input_indexes, &new_output_indexes);
  input_indexes->swap(new_input_indexes);
  output_indexes->swap(new_output_indexes);
}

void RestrictedAttentionComponent::GetComputationStructure(
    const std::vector<Index> &input_indexes,
    const std::vector<Index> &output_indexes,
    time_height_convolution::ConvolutionComputationIo *io) const {
  GetComputationIo(input_indexes, output_indexes, io);
  // if there was only one output and/or input index (unlikely),
  // just let the grid periodicity be time_stride_.
  if (io->t_step_out == 0) io->t_step_out = time_stride_;
  if (io->t_step_in == 0) io->t_step_in = time_stride_;

  // We need the grid step on the input and output to be the same, and to
  // divide time_stride_.  If someone is requesting the output more frequently
  // than time_stride_, then after this change we may end up computing more
  // outputs than we need, but this is not a configuration that I think is very
  // likely.  We let the grid step be the gcd of the input and output steps,
  // and of time_stride_.
  // The next few statements may have the effect of making the grid finer at the
  // input and output, while having the same start and end point.
  int32 t_step = Gcd(Gcd(io->t_step_out, io->t_step_in), time_stride_);
  int32 multiple_out = io->t_step_out / t_step,
      multiple_in = io->t_step_in / t_step;
  io->t_step_in = t_step;
  io->t_step_out = t_step;
  io->num_t_out = 1 + multiple_out * (io->num_t_out - 1);
  io->num_t_in = 1 + multiple_in * (io->num_t_in - 1);
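  // For example (illustrative numbers): if t_step_out is 3, t_step_in is 6 and
  // time_stride_ is 3, then t_step = Gcd(Gcd(3, 6), 3) = 3, multiple_out = 1
  // and multiple_in = 2, so the input grid becomes twice as fine (step 3
  // instead of 6) while keeping the same start and end times.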

  // Now ensure that the extent of the input has at least
  // the requested left-context and right-context; if
  // this increases the amount of input, we'll do zero-padding.
  int32 first_requested_input =
      io->start_t_out - (time_stride_ * num_left_inputs_),
      first_required_input =
      io->start_t_out - (time_stride_ * num_left_inputs_required_),
      last_t_out = io->start_t_out + (io->num_t_out - 1) * t_step,
      last_t_in = io->start_t_in + (io->num_t_in - 1) * t_step,
      last_requested_input = last_t_out + (time_stride_ * num_right_inputs_),
      last_required_input =
      last_t_out + (time_stride_ * num_right_inputs_required_);

  // check that we don't have *more* than the requested context,
  // but that we have at least the required context.
  KALDI_ASSERT(io->start_t_in >= first_requested_input &&
               last_t_in <= last_requested_input &&
               io->start_t_in <= first_required_input &&
               last_t_in >= last_required_input);

  // For the inputs that were requested, but not required,
  // we pad with zeros.  We pad the 'io' object, adding these
  // extra inputs structurally; in runtime they'll be set to zero.
  io->start_t_in = first_requested_input;
  io->num_t_in = 1 + (last_requested_input - first_requested_input) / t_step;
}

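// Note: context_dim_ is not written to disk below; it is a derived quantity
// that Read() recomputes from num_left_inputs_ and num_right_inputs_.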
void RestrictedAttentionComponent::Write(std::ostream &os, bool binary) const {
  WriteToken(os, binary, "<RestrictedAttentionComponent>");
  WriteToken(os, binary, "<NumHeads>");
  WriteBasicType(os, binary, num_heads_);
  WriteToken(os, binary, "<KeyDim>");
  WriteBasicType(os, binary, key_dim_);
  WriteToken(os, binary, "<ValueDim>");
  WriteBasicType(os, binary, value_dim_);
  WriteToken(os, binary, "<NumLeftInputs>");
  WriteBasicType(os, binary, num_left_inputs_);
  WriteToken(os, binary, "<NumRightInputs>");
  WriteBasicType(os, binary, num_right_inputs_);
  WriteToken(os, binary, "<TimeStride>");
  WriteBasicType(os, binary, time_stride_);
  WriteToken(os, binary, "<NumLeftInputsRequired>");
  WriteBasicType(os, binary, num_left_inputs_required_);
  WriteToken(os, binary, "<NumRightInputsRequired>");
  WriteBasicType(os, binary, num_right_inputs_required_);
  WriteToken(os, binary, "<OutputContext>");
  WriteBasicType(os, binary, output_context_);
  WriteToken(os, binary, "<KeyScale>");
  WriteBasicType(os, binary, key_scale_);
  WriteToken(os, binary, "<StatsCount>");
  WriteBasicType(os, binary, stats_count_);
  WriteToken(os, binary, "<EntropyStats>");
  entropy_stats_.Write(os, binary);
  WriteToken(os, binary, "<PosteriorStats>");
  posterior_stats_.Write(os, binary);
  WriteToken(os, binary, "</RestrictedAttentionComponent>");
}

void RestrictedAttentionComponent::Read(std::istream &is, bool binary) {
  ExpectOneOrTwoTokens(is, binary, "<RestrictedAttentionComponent>",
                       "<NumHeads>");
  ReadBasicType(is, binary, &num_heads_);
  ExpectToken(is, binary, "<KeyDim>");
  ReadBasicType(is, binary, &key_dim_);
  ExpectToken(is, binary, "<ValueDim>");
  ReadBasicType(is, binary, &value_dim_);
  ExpectToken(is, binary, "<NumLeftInputs>");
  ReadBasicType(is, binary, &num_left_inputs_);
  ExpectToken(is, binary, "<NumRightInputs>");
  ReadBasicType(is, binary, &num_right_inputs_);
  ExpectToken(is, binary, "<TimeStride>");
  ReadBasicType(is, binary, &time_stride_);
  ExpectToken(is, binary, "<NumLeftInputsRequired>");
  ReadBasicType(is, binary, &num_left_inputs_required_);
  ExpectToken(is, binary, "<NumRightInputsRequired>");
  ReadBasicType(is, binary, &num_right_inputs_required_);
  ExpectToken(is, binary, "<OutputContext>");
  ReadBasicType(is, binary, &output_context_);
  ExpectToken(is, binary, "<KeyScale>");
  ReadBasicType(is, binary, &key_scale_);
  ExpectToken(is, binary, "<StatsCount>");
  ReadBasicType(is, binary, &stats_count_);
  ExpectToken(is, binary, "<EntropyStats>");
  entropy_stats_.Read(is, binary);
  ExpectToken(is, binary, "<PosteriorStats>");
  posterior_stats_.Read(is, binary);
  ExpectToken(is, binary, "</RestrictedAttentionComponent>");

  context_dim_ = num_left_inputs_ + 1 + num_right_inputs_;
}


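// GetInputIndexes() lists, for one output index, the input indexes this
// component would like to see: one per context offset, i.e. context_dim_ =
// num_left_inputs_ + 1 + num_right_inputs_ frames, spaced time_stride_ apart.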
void RestrictedAttentionComponent::GetInputIndexes(
    const MiscComputationInfo &misc_info,
    const Index &output_index,
    std::vector<Index> *desired_indexes) const {
  KALDI_ASSERT(output_index.t != kNoTime);
  int32 first_time = output_index.t - (time_stride_ * num_left_inputs_),
      last_time = output_index.t + (time_stride_ * num_right_inputs_);
  desired_indexes->clear();
  desired_indexes->resize(context_dim_);
  int32 n = output_index.n, x = output_index.x,
      i = 0;
  for (int32 t = first_time; t <= last_time; t += time_stride_, i++) {
    (*desired_indexes)[i].n = n;
    (*desired_indexes)[i].t = t;
    (*desired_indexes)[i].x = x;
  }
  KALDI_ASSERT(i == context_dim_);
}


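// An output frame is computable as long as all *required* context frames
// (num-left-inputs-required / num-right-inputs-required) are present at the
// input; frames that were merely requested may be missing, and those are
// zero-padded later via the 'io' structure.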
bool RestrictedAttentionComponent::IsComputable(
    const MiscComputationInfo &misc_info,
    const Index &output_index,
    const IndexSet &input_index_set,
    std::vector<Index> *used_inputs) const {
  KALDI_ASSERT(output_index.t != kNoTime);
  Index index(output_index);

  if (used_inputs != NULL) {
    int32 first_time = output_index.t - (time_stride_ * num_left_inputs_),
        last_time = output_index.t + (time_stride_ * num_right_inputs_);
    used_inputs->clear();
    used_inputs->reserve(context_dim_);

    for (int32 t = first_time; t <= last_time; t += time_stride_) {
      index.t = t;
      if (input_index_set(index)) {
        // This input index is available.
        used_inputs->push_back(index);
      } else {
        // This input index is not available.
        int32 offset = (t - output_index.t) / time_stride_;
        if (offset >= -num_left_inputs_required_ &&
            offset <= num_right_inputs_required_) {
          used_inputs->clear();
          return false;
        }
      }
    }
    // All required time-offsets of the output were computable. -> return true.
    return true;
  } else {
    int32 t = output_index.t,
        first_time_required = t - (time_stride_ * num_left_inputs_required_),
        last_time_required = t + (time_stride_ * num_right_inputs_required_);
    for (int32 t = first_time_required;
         t <= last_time_required;
         t += time_stride_) {
      index.t = t;
      if (!input_index_set(index))
        return false;
    }
    return true;
  }
}


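// CreateIndexesVector() lays the indexes out on a regular grid over t and the
// (n, x) pairs; grid positions that are not present in 'index_set' get
// t == kNoTime, marking them as placeholder rows to be zeroed rather than
// computed.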
// static
void RestrictedAttentionComponent::CreateIndexesVector(
    const std::vector<std::pair<int32, int32> > &n_x_pairs,
    int32 t_start, int32 t_step, int32 num_t_values,
    const std::unordered_set<Index, IndexHasher> &index_set,
    std::vector<Index> *output_indexes) {
  output_indexes->resize(static_cast<size_t>(num_t_values) * n_x_pairs.size());
  std::vector<Index>::iterator out_iter = output_indexes->begin();
  for (int32 t = t_start; t < t_start + (t_step * num_t_values); t += t_step) {
    std::vector<std::pair<int32, int32> >::const_iterator
        iter = n_x_pairs.begin(), end = n_x_pairs.end();
    for (; iter != end; ++iter) {
      out_iter->n = iter->first;
      out_iter->t = t;
      out_iter->x = iter->second;
      if (index_set.count(*out_iter) == 0)
        out_iter->t = kNoTime;
      ++out_iter;
    }
  }
  KALDI_ASSERT(out_iter == output_indexes->end());
}

void RestrictedAttentionComponent::GetIndexes(
    const std::vector<Index> &input_indexes,
    const std::vector<Index> &output_indexes,
    time_height_convolution::ConvolutionComputationIo &io,
    std::vector<Index> *new_input_indexes,
    std::vector<Index> *new_output_indexes) const {

  std::unordered_set<Index, IndexHasher> input_set, output_set;
  for (std::vector<Index>::const_iterator iter = input_indexes.begin();
       iter != input_indexes.end(); ++iter)
    input_set.insert(*iter);
  for (std::vector<Index>::const_iterator iter = output_indexes.begin();
       iter != output_indexes.end(); ++iter)
    output_set.insert(*iter);

  std::vector<std::pair<int32, int32> > n_x_pairs;
  GetNxList(input_indexes, &n_x_pairs);  // the n,x pairs at the output will be
                                         // identical.
  KALDI_ASSERT(n_x_pairs.size() == io.num_images);
  CreateIndexesVector(n_x_pairs, io.start_t_in, io.t_step_in, io.num_t_in,
                      input_set, new_input_indexes);
  CreateIndexesVector(n_x_pairs, io.start_t_out, io.t_step_out, io.num_t_out,
                      output_set, new_output_indexes);
}

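// The PrecomputedIndexes object for this component only needs to hold the
// ConvolutionComputationIo structure 'io', which describes the regular grid
// of time indexes (and the number of sequences) at the input and output.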
ComponentPrecomputedIndexes* RestrictedAttentionComponent::PrecomputeIndexes(
    const MiscComputationInfo &,  // misc_info
    const std::vector<Index> &input_indexes,
    const std::vector<Index> &output_indexes,
    bool)  // need_backprop
    const {
  PrecomputedIndexes *ans = new PrecomputedIndexes();
  GetComputationStructure(input_indexes, output_indexes, &(ans->io));
  if (GetVerboseLevel() >= 2) {
    // what follows is just a check.
    std::vector<Index> new_input_indexes, new_output_indexes;
    GetIndexes(input_indexes, output_indexes, ans->io,
               &new_input_indexes, &new_output_indexes);
    // input_indexes and output_indexes should be the ones that were
    // output by ReorderIndexes(), so they should already
    // have gone through the GetComputationStructure()->GetIndexes()
    // procedure.  Applying the same procedure twice is supposed to
    // give unchanged results.
    KALDI_ASSERT(input_indexes == new_input_indexes &&
                 output_indexes == new_output_indexes);
  }
  return ans;
}



ComponentPrecomputedIndexes*
RestrictedAttentionComponent::PrecomputedIndexes::Copy() const {
  return new PrecomputedIndexes(*this);
}

void RestrictedAttentionComponent::PrecomputedIndexes::Write(
    std::ostream &os, bool binary) const {
  WriteToken(os, binary, "<RestrictedAttentionComponentPrecomputedIndexes>");
  WriteToken(os, binary, "<Io>");
  io.Write(os, binary);
  WriteToken(os, binary, "</RestrictedAttentionComponentPrecomputedIndexes>");
}

void RestrictedAttentionComponent::PrecomputedIndexes::Read(
    std::istream &is, bool binary) {
  ExpectOneOrTwoTokens(is, binary,
                       "<RestrictedAttentionComponentPrecomputedIndexes>",
                       "<Io>");
  io.Read(is, binary);
  ExpectToken(is, binary, "</RestrictedAttentionComponentPrecomputedIndexes>");
}


}  // namespace nnet3
}  // namespace kaldi