// NOTE: this file was recovered from a Doxygen HTML export of
// nnet3/nnet-component-itf.cc; the numeric prefixes on the lines below are
// Doxygen line numbers, not code.
1 // nnet3/nnet-component-itf.cc
2 
3 // Copyright 2015 Johns Hopkins University (author: Daniel Povey)
4 // 2015 Guoguo Chen
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
#include <iterator>
#include <sstream>
#include <iomanip>

#include "nnet3/nnet-component-itf.h"
#include "nnet3/nnet-simple-component.h"
#include "nnet3/nnet-normalize-component.h"
#include "nnet3/nnet-general-component.h"
#include "nnet3/nnet-convolutional-component.h"
#include "nnet3/nnet-attention-component.h"
#include "nnet3/nnet-tdnn-component.h"
#include "nnet3/nnet-computation-graph.h"
#include "nnet3/nnet-parse.h"
33 
34 
35 
36 // \file This file contains some more-generic component code: things in base classes.
37 // See nnet-component.cc for the code of the actual Components.
38 
39 namespace kaldi {
40 namespace nnet3 {
41 
43  bool binary) {
44  std::string token;
45  ReadToken(is, binary, &token); // e.g. "<DistributePrecomputedComponentIndexes>".
46  token.erase(0, 1); // erase "<".
47  token.erase(token.length()-1); // erase ">".
49  if (!ans)
50  KALDI_ERR << "Unknown ComponentPrecomputedIndexes type " << token;
51  ans->Read(is, binary);
52  return ans;
53 }
54 
56  const std::string &cpi_type) {
57  ComponentPrecomputedIndexes *ans = NULL;
58  if (cpi_type == "DistributeComponentPrecomputedIndexes") {
60  } else if (cpi_type == "StatisticsExtractionComponentPrecomputedIndexes") {
62  } else if (cpi_type == "StatisticsPoolingComponentPrecomputedIndexes") {
64  } else if (cpi_type == "BackpropTruncationComponentPrecomputedIndexes") {
66  } else if (cpi_type == "TimeHeightConvolutionComponentPrecomputedIndexes") {
68  } else if (cpi_type == "RestrictedAttentionComponentPrecomputedIndexes") {
70  } else if (cpi_type == "GeneralDropoutComponentPrecomputedIndexes") {
72  } else if (cpi_type == "SpecAugmentTimeMaskComponentPrecomputedIndexes") {
74  } else if (cpi_type == "TdnnComponentPrecomputedIndexes") {
76  }
77  if (ans != NULL) {
78  KALDI_ASSERT(cpi_type == ans->Type());
79  }
80  return ans;
81 }
82 
83 // static
84 Component* Component::ReadNew(std::istream &is, bool binary) {
85  std::string token;
86  ReadToken(is, binary, &token); // e.g. "<SigmoidComponent>".
87  token.erase(0, 1); // erase "<".
88  token.erase(token.length()-1); // erase ">".
89  Component *ans = NewComponentOfType(token);
90  if (!ans)
91  KALDI_ERR << "Unknown component type " << token;
92  ans->Read(is, binary);
93  return ans;
94 }
95 
96 
97 // static
98 Component* Component::NewComponentOfType(const std::string &component_type) {
99  Component *ans = NULL;
100  if (component_type == "SigmoidComponent") {
101  ans = new SigmoidComponent();
102  } else if (component_type == "TanhComponent") {
103  ans = new TanhComponent();
104  } else if (component_type == "SoftmaxComponent") {
105  ans = new SoftmaxComponent();
106  } else if (component_type == "LogSoftmaxComponent") {
107  ans = new LogSoftmaxComponent();
108  } else if (component_type == "RectifiedLinearComponent") {
109  ans = new RectifiedLinearComponent();
110  } else if (component_type == "NormalizeComponent") {
111  ans = new NormalizeComponent();
112  } else if (component_type == "PnormComponent") {
113  ans = new PnormComponent();
114  } else if (component_type == "AffineComponent") {
115  ans = new AffineComponent();
116  } else if (component_type == "LinearComponent") {
117  ans = new LinearComponent();
118  } else if (component_type == "NaturalGradientAffineComponent") {
119  ans = new NaturalGradientAffineComponent();
120  } else if (component_type == "PerElementScaleComponent") {
121  ans = new PerElementScaleComponent();
122  } else if (component_type == "NaturalGradientPerElementScaleComponent") {
124  } else if (component_type == "PerElementOffsetComponent") {
125  ans = new PerElementOffsetComponent();
126  } else if (component_type == "SumGroupComponent") {
127  ans = new SumGroupComponent();
128  } else if (component_type == "FixedAffineComponent") {
129  ans = new FixedAffineComponent();
130  } else if (component_type == "FixedScaleComponent") {
131  ans = new FixedScaleComponent();
132  } else if (component_type == "FixedBiasComponent") {
133  ans = new FixedBiasComponent();
134  } else if (component_type == "NoOpComponent") {
135  ans = new NoOpComponent();
136  } else if (component_type == "ClipGradientComponent") {
137  ans = new ClipGradientComponent();
138  } else if (component_type == "ElementwiseProductComponent") {
139  ans = new ElementwiseProductComponent();
140  } else if (component_type == "ConvolutionComponent") {
141  ans = new ConvolutionComponent();
142  } else if (component_type == "TdnnComponent") {
143  ans = new TdnnComponent();
144  } else if (component_type == "MaxpoolingComponent") {
145  ans = new MaxpoolingComponent();
146  } else if (component_type == "PermuteComponent") {
147  ans = new PermuteComponent();
148  } else if (component_type == "DistributeComponent") {
149  ans = new DistributeComponent();
150  } else if (component_type == "CompositeComponent") {
151  ans = new CompositeComponent();
152  } else if (component_type == "RepeatedAffineComponent") {
153  ans = new RepeatedAffineComponent();
154  } else if (component_type == "BlockAffineComponent") {
155  ans = new BlockAffineComponent();
156  } else if (component_type == "NaturalGradientRepeatedAffineComponent") {
158  } else if (component_type == "StatisticsExtractionComponent") {
159  ans = new StatisticsExtractionComponent();
160  } else if (component_type == "StatisticsPoolingComponent") {
161  ans = new StatisticsPoolingComponent();
162  } else if (component_type == "ConstantFunctionComponent") {
163  ans = new ConstantFunctionComponent();
164  } else if (component_type == "ConstantComponent") {
165  ans = new ConstantComponent();
166  } else if (component_type == "DropoutComponent") {
167  ans = new DropoutComponent();
168  } else if (component_type == "DropoutMaskComponent") {
169  ans = new DropoutMaskComponent();
170  } else if (component_type == "GeneralDropoutComponent") {
171  ans = new GeneralDropoutComponent();
172  } else if (component_type == "SpecAugmentTimeMaskComponent") {
173  ans = new SpecAugmentTimeMaskComponent();
174  } else if (component_type == "BackpropTruncationComponent") {
175  ans = new BackpropTruncationComponent();
176  } else if (component_type == "LstmNonlinearityComponent") {
177  ans = new LstmNonlinearityComponent();
178  } else if (component_type == "BatchNormComponent") {
179  ans = new BatchNormComponent();
180  } else if (component_type == "TimeHeightConvolutionComponent") {
181  ans = new TimeHeightConvolutionComponent();
182  } else if (component_type == "RestrictedAttentionComponent") {
183  ans = new RestrictedAttentionComponent();
184  } else if (component_type == "SumBlockComponent") {
185  ans = new SumBlockComponent();
186  } else if (component_type == "GruNonlinearityComponent") {
187  ans = new GruNonlinearityComponent();
188  } else if (component_type == "OutputGruNonlinearityComponent") {
189  ans = new OutputGruNonlinearityComponent();
190  } else if (component_type == "ScaleAndOffsetComponent") {
191  ans = new ScaleAndOffsetComponent();
192  }
193  if (ans != NULL) {
194  KALDI_ASSERT(component_type == ans->Type());
195  }
196  return ans;
197 }
198 
199 std::string Component::Info() const {
200  std::stringstream stream;
201  stream << Type() << ", input-dim=" << InputDim()
202  << ", output-dim=" << OutputDim();
203  return stream.str();
204 }
205 
207  const Index &output_index,
208  std::vector<Index> *input_indexes) const {
209  input_indexes->resize(1);
210  (*input_indexes)[0] = output_index;
211 }
212 
214  const Index &output_index,
215  const IndexSet &input_index_set,
216  std::vector<Index> *used_inputs) const {
217  // the default Component dependency is for an output index to map directly to
218  // the same input index, which is required to compute the output.
219  if (!input_index_set(output_index))
220  return false;
221  if (used_inputs) {
222  used_inputs->clear();
223  used_inputs->push_back(output_index);
224  }
225  return true;
226 }
227 
228 
230  learning_rate_(other.learning_rate_),
231  learning_rate_factor_(other.learning_rate_factor_),
232  l2_regularize_(other.l2_regularize_),
233  is_gradient_(other.is_gradient_),
234  max_change_(other.max_change_) { }
235 
236 
238  const UpdatableComponent &other) {
242  is_gradient_ = other.is_gradient_;
243  max_change_ = other.max_change_;
244 }
245 
246 // If these defaults are changed, the defaults in the constructor that
247 // takes no arguments should be changed too.
249  learning_rate_ = 0.001;
250  cfl->GetValue("learning-rate", &learning_rate_);
251  learning_rate_factor_ = 1.0;
252  cfl->GetValue("learning-rate-factor", &learning_rate_factor_);
253  max_change_ = 0.0;
254  cfl->GetValue("max-change", &max_change_);
255  l2_regularize_ = 0.0;
256  cfl->GetValue("l2-regularize", &l2_regularize_);
257  if (learning_rate_ < 0.0 || learning_rate_factor_ < 0.0 ||
258  max_change_ < 0.0 || l2_regularize_ < 0.0)
259  KALDI_ERR << "Bad initializer " << cfl->WholeLine();
260 }
261 
262 
263 std::string UpdatableComponent::ReadUpdatableCommon(std::istream &is,
264  bool binary) {
265  std::ostringstream opening_tag;
266  opening_tag << '<' << this->Type() << '>';
267  std::string token;
268  ReadToken(is, binary, &token);
269  if (token == opening_tag.str()) {
270  // if the first token is the opening tag, then
271  // ignore it and get the next tag.
272  ReadToken(is, binary, &token);
273  }
274  if (token == "<LearningRateFactor>") {
275  ReadBasicType(is, binary, &learning_rate_factor_);
276  ReadToken(is, binary, &token);
277  } else {
278  learning_rate_factor_ = 1.0;
279  }
280  if (token == "<IsGradient>") {
281  ReadBasicType(is, binary, &is_gradient_);
282  ReadToken(is, binary, &token);
283  } else {
284  is_gradient_ = false;
285  }
286  if (token == "<MaxChange>") {
287  ReadBasicType(is, binary, &max_change_);
288  ReadToken(is, binary, &token);
289  } else {
290  max_change_ = 0.0;
291  }
292  if (token == "<L2Regularize>") {
293  ReadBasicType(is, binary, &l2_regularize_);
294  ReadToken(is, binary, &token);
295  } else {
296  l2_regularize_ = 0.0;
297  }
298  if (token == "<LearningRate>") {
299  ReadBasicType(is, binary, &learning_rate_);
300  return "";
301  } else {
302  return token;
303  }
304 }
305 
307  bool binary) const {
308  std::ostringstream opening_tag;
309  opening_tag << '<' << this->Type() << '>';
310  std::string token;
311  WriteToken(os, binary, opening_tag.str());
312  if (learning_rate_factor_ != 1.0) {
313  WriteToken(os, binary, "<LearningRateFactor>");
315  }
316  if (is_gradient_) {
317  WriteToken(os, binary, "<IsGradient>");
318  WriteBasicType(os, binary, is_gradient_);
319  }
320  if (max_change_ > 0.0) {
321  WriteToken(os, binary, "<MaxChange>");
322  WriteBasicType(os, binary, max_change_);
323  }
324  if (l2_regularize_ > 0.0) {
325  WriteToken(os, binary, "<L2Regularize>");
326  WriteBasicType(os, binary, l2_regularize_);
327  }
328  WriteToken(os, binary, "<LearningRate>");
329  WriteBasicType(os, binary, learning_rate_);
330 }
331 
332 
333 std::string UpdatableComponent::Info() const {
334  std::stringstream stream;
335  stream << Type() << ", input-dim=" << InputDim()
336  << ", output-dim=" << OutputDim() << ", learning-rate="
337  << LearningRate();
338  if (is_gradient_)
339  stream << ", is-gradient=true";
340  if (l2_regularize_ != 0.0)
341  stream << ", l2-regularize=" << l2_regularize_;
342  if (learning_rate_factor_ != 1.0)
343  stream << ", learning-rate-factor=" << learning_rate_factor_;
344  if (max_change_ > 0.0)
345  stream << ", max-change=" << max_change_;
346  return stream.str();
347 }
348 
350  const CuMatrixBase<BaseFloat> &out_value,
351  const CuMatrixBase<BaseFloat> *deriv) {
352  KALDI_ASSERT(out_value.NumCols() == dim_);
353 
354  // Check we have the correct dimensions.
355  if (value_sum_.Dim() != dim_ ||
356  (deriv != NULL && deriv_sum_.Dim() != dim_)) {
357  if (value_sum_.Dim() != dim_) {
358  value_sum_.Resize(dim_);
359  count_ = 0.0;
360  }
361  if (deriv != NULL && deriv_sum_.Dim() != dim_) {
362  deriv_sum_.Resize(dim_);
363  count_ = 0.0;
364  value_sum_.SetZero();
365  }
366  }
367  count_ += out_value.NumRows();
368  CuVector<BaseFloat> temp(dim_);
369  temp.AddRowSumMat(1.0, out_value, 0.0);
370  value_sum_.AddVec(1.0, temp);
371  if (deriv != NULL) {
372  temp.AddRowSumMat(1.0, *deriv, 0.0);
373  deriv_sum_.AddVec(1.0, temp);
374  }
375 }
376 
378  const CuMatrixBase<BaseFloat> &out_deriv) {
379  // Only store these stats about every 4 minibatches. Make sure to always
380  // store the stats on the very first minibatch, or it would interact badly
381  // with the ConsolidateMemory() code.
382  if (RandInt(0, 3) == 0 && oderiv_count_ != 0)
383  return;
384 
385  KALDI_ASSERT(out_deriv.NumCols() == dim_);
386 
387  // Check we have the correct dimensions.
388  if (oderiv_sumsq_.Dim() != dim_) {
389  oderiv_sumsq_.Resize(dim_);
390  oderiv_count_ = 0.0;
391  }
392  CuVector<BaseFloat> temp(dim_);
393  temp.AddDiagMat2(1.0, out_deriv, kTrans, 0.0);
394  oderiv_sumsq_.AddVec(1.0, temp);
395  oderiv_count_ += out_deriv.NumRows();
396 }
397 
398 
400  value_sum_.SetZero();
401  deriv_sum_.SetZero();
402  oderiv_sumsq_.SetZero();
403  count_ = 0.0;
404  oderiv_count_ = 0.0;
405  num_dims_self_repaired_ = 0.0;
406  num_dims_processed_ = 0.0;
407 }
408 
409 std::string NonlinearComponent::Info() const {
410  std::stringstream stream;
411  stream << Type() << ", dim=" << dim_;
412  if (block_dim_ != dim_)
413  stream << ", block-dim=" << block_dim_;
414  if (self_repair_lower_threshold_ != BaseFloat(kUnsetThreshold))
415  stream << ", self-repair-lower-threshold=" << self_repair_lower_threshold_;
416  if (self_repair_upper_threshold_ != BaseFloat(kUnsetThreshold))
417  stream << ", self-repair-upper-threshold=" << self_repair_upper_threshold_;
418  if (self_repair_scale_ != 0.0)
419  stream << ", self-repair-scale=" << self_repair_scale_;
420  if (count_ > 0 && value_sum_.Dim() == dim_) {
421  stream << ", count=" << std::setprecision(3) << count_
422  << std::setprecision(6);
423  stream << ", self-repaired-proportion="
424  << (num_dims_processed_ > 0 ?
425  num_dims_self_repaired_ / num_dims_processed_ : 0);
426  Vector<double> value_avg_dbl(value_sum_);
427  Vector<BaseFloat> value_avg(value_avg_dbl);
428  value_avg.Scale(1.0 / count_);
429  stream << ", value-avg=" << SummarizeVector(value_avg);
430  if (deriv_sum_.Dim() == dim_) {
431  Vector<double> deriv_avg(deriv_sum_);
432  deriv_avg.Scale(1.0 / count_);
433  stream << ", deriv-avg=" << SummarizeVector(deriv_avg);
434  }
435  }
436  if (oderiv_count_ > 0 && oderiv_sumsq_.Dim() == dim_) {
437  Vector<double> oderiv_rms(oderiv_sumsq_);
438  oderiv_rms.Scale(1.0 / oderiv_count_);
439  // The ApplyMin() is so that the statement after it does not fail even if we
440  // had subtracted models (e.g. in full_progress.*.log).
441  oderiv_rms.ApplyFloor(0.0);
442  oderiv_rms.ApplyPow(0.5);
443  stream << ", oderiv-rms=" << SummarizeVector(oderiv_rms)
444  << ", oderiv-count=" << oderiv_count_;
445  }
446  return stream.str();
447 }
448 
450  value_sum_.Scale(scale);
451  deriv_sum_.Scale(scale);
452  oderiv_sumsq_.Scale(scale);
453  count_ *= scale;
454  oderiv_count_ *= scale;
455  num_dims_self_repaired_ *= scale;
456  num_dims_processed_ *= scale;
457 }
458 
459 void NonlinearComponent::Add(BaseFloat alpha, const Component &other_in) {
460  const NonlinearComponent *other =
461  dynamic_cast<const NonlinearComponent*>(&other_in);
462  KALDI_ASSERT(other != NULL);
463  if (value_sum_.Dim() == 0 && other->value_sum_.Dim() != 0)
464  value_sum_.Resize(other->value_sum_.Dim());
465  if (deriv_sum_.Dim() == 0 && other->deriv_sum_.Dim() != 0)
466  deriv_sum_.Resize(other->deriv_sum_.Dim());
467  if (oderiv_sumsq_.Dim() == 0 && other->oderiv_sumsq_.Dim() != 0)
468  oderiv_sumsq_.Resize(other->oderiv_sumsq_.Dim());
469  if (other->value_sum_.Dim() != 0)
470  value_sum_.AddVec(alpha, other->value_sum_);
471  if (other->deriv_sum_.Dim() != 0)
472  deriv_sum_.AddVec(alpha, other->deriv_sum_);
473  if (other->oderiv_sumsq_.Dim() != 0)
474  oderiv_sumsq_.AddVec(alpha, other->oderiv_sumsq_);
475  count_ += alpha * other->count_;
476  oderiv_count_ += alpha * other->oderiv_count_;
477  num_dims_self_repaired_ += alpha * other->num_dims_self_repaired_;
478  num_dims_processed_ += alpha * other->num_dims_processed_;
479 }
480 
481 void NonlinearComponent::Read(std::istream &is, bool binary) {
482  std::ostringstream ostr_beg, ostr_end;
483  ostr_beg << "<" << Type() << ">"; // e.g. "<SigmoidComponent>"
484  ostr_end << "</" << Type() << ">"; // e.g. "</SigmoidComponent>"
485  ExpectOneOrTwoTokens(is, binary, ostr_beg.str(), "<Dim>");
486  ReadBasicType(is, binary, &dim_); // Read dimension.
487  if (PeekToken(is, binary) == 'B') {
488  ExpectToken(is, binary, "<BlockDim>");
489  ReadBasicType(is, binary, &block_dim_);
490  } else {
491  block_dim_ = dim_;
492  }
493  ExpectToken(is, binary, "<ValueAvg>");
494  value_sum_.Read(is, binary);
495  ExpectToken(is, binary, "<DerivAvg>");
496  deriv_sum_.Read(is, binary);
497  ExpectToken(is, binary, "<Count>");
498  ReadBasicType(is, binary, &count_);
499  if (PeekToken(is, binary) == 'O') {
500  ExpectToken(is, binary, "<OderivRms>");
501  oderiv_sumsq_.Read(is, binary);
502  oderiv_sumsq_.ApplyPow(2.0);
503  ExpectToken(is, binary, "<OderivCount>");
504  ReadBasicType(is, binary, &oderiv_count_);
505  } else {
506  oderiv_count_ = 0.0;
507  oderiv_sumsq_.Resize(0);
508  }
509  value_sum_.Scale(count_);
510  deriv_sum_.Scale(count_);
511  oderiv_sumsq_.Scale(oderiv_count_);
512 
513  std::string token;
514  ReadToken(is, binary, &token);
515  if (token[0] != '<') {
516  // this should happen only rarely, in case we couldn't push back the
517  // '<' to the stream in PeekToken().
518  token = '<' + token;
519  }
520  if (token == "<NumDimsSelfRepaired>") {
521  ReadBasicType(is, binary, &num_dims_self_repaired_);
522  ReadToken(is, binary, &token);
523  }
524  if (token == "<NumDimsProcessed>") {
525  ReadBasicType(is, binary, &num_dims_processed_);
526  ReadToken(is, binary, &token);
527  }
528  if (token == "<SelfRepairLowerThreshold>") {
529  ReadBasicType(is, binary, &self_repair_lower_threshold_);
530  ReadToken(is, binary, &token);
531  }
532  if (token == "<SelfRepairUpperThreshold>") {
533  ReadBasicType(is, binary, &self_repair_upper_threshold_);
534  ReadToken(is, binary, &token);
535  }
536  if (token == "<SelfRepairScale>") {
537  ReadBasicType(is, binary, &self_repair_scale_);
538  ReadToken(is, binary, &token);
539  }
540  if (token != ostr_end.str()) {
541  KALDI_ERR << "Expected token " << ostr_end.str()
542  << ", got " << token;
543  }
544 }
545 
546 void NonlinearComponent::Write(std::ostream &os, bool binary) const {
547  std::ostringstream ostr_beg, ostr_end;
548  ostr_beg << "<" << Type() << ">"; // e.g. "<SigmoidComponent>"
549  ostr_end << "</" << Type() << ">"; // e.g. "</SigmoidComponent>"
550  WriteToken(os, binary, ostr_beg.str());
551  WriteToken(os, binary, "<Dim>");
552  WriteBasicType(os, binary, dim_);
553  if (block_dim_ != dim_) {
554  WriteToken(os, binary, "<BlockDim>");
555  WriteBasicType(os, binary, block_dim_);
556  }
557  // Write the values and derivatives in a count-normalized way, for
558  // greater readability in text form.
559  WriteToken(os, binary, "<ValueAvg>");
560  Vector<BaseFloat> temp(value_sum_);
561  if (count_ != 0.0) temp.Scale(1.0 / count_);
562  temp.Write(os, binary);
563 
564  WriteToken(os, binary, "<DerivAvg>");
565  temp.Resize(deriv_sum_.Dim());
566  temp.CopyFromVec(deriv_sum_);
567  if (count_ != 0.0) temp.Scale(1.0 / count_);
568  temp.Write(os, binary);
569 
570  WriteToken(os, binary, "<Count>");
571  WriteBasicType(os, binary, count_);
572 
573  WriteToken(os, binary, "<OderivRms>");
574  temp.Resize(oderiv_sumsq_.Dim());
575  temp.CopyFromVec(oderiv_sumsq_);
576  if (oderiv_count_ != 0.0) temp.Scale(1.0 / oderiv_count_);
577  // The ApplyMin() is so that the statement after it does not fail even if we
578  // had subtracted models (e.g. in full_progress.*.log).
579  temp.ApplyFloor(0.0);
580  temp.ApplyPow(0.5);
581  temp.Write(os, binary);
582 
583  WriteToken(os, binary, "<OderivCount>");
584  WriteBasicType(os, binary, oderiv_count_);
585 
586  WriteToken(os, binary, "<NumDimsSelfRepaired>");
587  WriteBasicType(os, binary, num_dims_self_repaired_);
588  WriteToken(os, binary, "<NumDimsProcessed>");
589  WriteBasicType(os, binary, num_dims_processed_);
590  if (self_repair_lower_threshold_ != kUnsetThreshold) {
591  WriteToken(os, binary, "<SelfRepairLowerThreshold>");
592  WriteBasicType(os, binary, self_repair_lower_threshold_);
593  }
594  if (self_repair_upper_threshold_ != kUnsetThreshold) {
595  WriteToken(os, binary, "<SelfRepairUpperThreshold>");
596  WriteBasicType(os, binary, self_repair_upper_threshold_);
597  }
598  if (self_repair_scale_ != 0.0) {
599  WriteToken(os, binary, "<SelfRepairScale>");
600  WriteBasicType(os, binary, self_repair_scale_);
601  }
602  WriteToken(os, binary, ostr_end.str());
603 }
604 
606  dim_(-1), block_dim_(-1), count_(0.0), oderiv_count_(0.0),
607  num_dims_self_repaired_(0.0), num_dims_processed_(0.0),
608  self_repair_lower_threshold_(kUnsetThreshold),
609  self_repair_upper_threshold_(kUnsetThreshold),
610  self_repair_scale_(0.0) { }
611 
613  dim_(other.dim_), block_dim_(other.block_dim_),
615  count_(other.count_), oderiv_sumsq_(other.oderiv_sumsq_),
622 
624  bool ok = cfl->GetValue("dim", &dim_);
625  block_dim_ = dim_;
626  cfl->GetValue("block-dim", &block_dim_);
627  cfl->GetValue("self-repair-lower-threshold", &self_repair_lower_threshold_);
628  cfl->GetValue("self-repair-upper-threshold", &self_repair_upper_threshold_);
629  cfl->GetValue("self-repair-scale", &self_repair_scale_);
630  if (!ok || cfl->HasUnusedValues() || dim_ <= 0 ||
631  block_dim_ <= 0 || dim_ % block_dim_ != 0)
632  KALDI_ERR << "Invalid initializer for layer of type "
633  << Type() << ": \"" << cfl->WholeLine() << "\"";
634 }
635 
637  { CuVector<double> temp(value_sum_); value_sum_.Swap(&temp); }
638  { CuVector<double> temp(deriv_sum_); deriv_sum_.Swap(&temp); }
640 }
641 
642 } // namespace nnet3
643 } // namespace kaldi
virtual bool IsComputable(const MiscComputationInfo &misc_info, const Index &output_index, const IndexSet &input_index_set, std::vector< Index > *used_inputs) const
This function only does something interesting for non-simple Components, and it exists to make it pos...
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
const std::string WholeLine()
Definition: text-utils.h:230
virtual void Add(BaseFloat alpha, const Component &other)
This virtual function when called by – an UpdatableComponent adds the parameters of another updatabl...
virtual void Read(std::istream &os, bool binary)=0
Abstract base-class for neural-net components.
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
An abstract representation of a set of Indexes.
TdnnComponent is a more memory-efficient alternative to manually splicing several frames of input and...
void Write(std::ostream &Out, bool binary) const
Writes to C++ stream (option to write in binary).
void InitLearningRatesFromConfig(ConfigLine *cfl)
void SetUpdatableConfigs(const UpdatableComponent &other)
virtual void GetInputIndexes(const MiscComputationInfo &misc_info, const Index &output_index, std::vector< Index > *desired_indexes) const
This function only does something interesting for non-simple Components.
std::string SummarizeVector(const VectorBase< float > &vec)
Returns a string that summarizes a vector fairly succintly, for printing stats in info lines...
Definition: nnet-parse.cc:111
void AddDiagMat2(Real alpha, const CuMatrixBase< Real > &M, MatrixTransposeType trans, Real beta)
Add the diagonal of a matrix times itself: *this = diag(M M^T) + beta * *this (if trans == kNoTrans)...
Definition: cu-vector.cc:595
void ReadToken(std::istream &is, bool binary, std::string *str)
ReadToken gets the next token and puts it in str (exception on failure).
Definition: io-funcs.cc:154
virtual int32 OutputDim() const =0
Returns output-dimension of this component.
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
BaseFloat max_change_
configuration value for imposing max-change
PermuteComponent changes the order of the columns (i.e.
Contains component(s) related to attention models.
CompositeComponent is a component representing a sequence of [simple] components. ...
This file contains declarations of components that in one way or another normalize their input: Norma...
struct Index is intended to represent the various indexes by which we number the rows of the matrices...
Definition: nnet-common.h:44
virtual void Read(std::istream &is, bool binary)=0
Read function (used after we know the type of the Component); accepts input that is missing the token...
virtual void Write(std::ostream &os, bool binary) const
Write component to stream.
FixedScaleComponent applies a fixed per-element scale; it&#39;s similar to the Rescale component in the n...
This file contains declarations of components that are "simple", meaning they don&#39;t care about the in...
SpecAugmentTimeMaskComponent implements the time part of SpecAugment.
void ExpectOneOrTwoTokens(std::istream &is, bool binary, const std::string &token1, const std::string &token2)
This function is like ExpectToken but for two tokens, and it will either accept token1 and then token...
Definition: text-utils.cc:536
void ApplyFloor(Real floor_val, MatrixIndexT *floored_count=nullptr)
Applies floor to all elements.
Definition: kaldi-vector.h:149
void CopyFromVec(const VectorBase< Real > &v)
Copy data from another vector (must match own size).
NaturalGradientPerElementScaleComponent is like PerElementScaleComponent but it uses a natural gradie...
virtual void ConsolidateMemory()
This virtual function relates to memory management, and avoiding fragmentation.
float BaseFloat
Definition: kaldi-types.h:29
virtual std::string Info() const
Returns some text-form information about this component, for diagnostics.
virtual void InitFromConfig(ConfigLine *cfl)
Initialize, from a ConfigLine object.
static void ExpectToken(const std::string &token, const std::string &what_we_are_parsing, const std::string **next_token)
SumGroupComponent is used to sum up groups of posteriors.
BaseFloat learning_rate_
learning rate (typically 0.0..0.01)
You can view this as an overflow from nnet-simple-component.h.
std::string ReadUpdatableCommon(std::istream &is, bool binary)
#define KALDI_ERR
Definition: kaldi-error.h:147
static Component * ReadNew(std::istream &is, bool binary)
Read component from stream (works out its type). Dies on error.
This Component takes a larger input-dim than output-dim, where the input-dim must be a multiple of th...
BaseFloat learning_rate_factor_
learning rate factor (normally 1.0, but can be set to another < value so that when < you call SetLear...
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
BaseFloat l2_regularize_
L2 regularization constant.
void Scale(Real alpha)
Multiplies all elements by this constant.
int PeekToken(std::istream &is, bool binary)
PeekToken will return the first character of the next token, or -1 if end of file.
Definition: io-funcs.cc:170
Class UpdatableComponent is a Component which has trainable parameters; it extends the interface of C...
static ComponentPrecomputedIndexes * ReadNew(std::istream &is, bool binary)
virtual std::string Type() const =0
Returns a string such as "SigmoidComponent", describing the type of the object.
GeneralDropoutComponent implements dropout, including a continuous variant where the thing we multipl...
void StoreBackpropStats(const CuMatrixBase< BaseFloat > &out_deriv)
Matrix for CUDA computing.
Definition: matrix-common.h:69
MatrixIndexT NumCols() const
Definition: cu-matrix.h:216
virtual std::string Type() const =0
virtual void Scale(BaseFloat scale)
This virtual function when called on – an UpdatableComponent scales the parameters by "scale" when c...
This class is responsible for parsing input like hi-there xx=yyy a=b c empty= f-oo=Append(bar, sss) ba_z=123 bing=&#39;a b c&#39; baz="a b c d=&#39;a b&#39; e" and giving you access to the fields, in this case.
Definition: text-utils.h:205
RestrictedAttentionComponent implements an attention model with restricted temporal context...
virtual std::string Info() const
Returns some text-form information about this component, for diagnostics.
bool is_gradient_
True if this component is to be treated as a gradient rather than as parameters.
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
PerElementScaleComponent scales each dimension of its input with a separate trainable scale; it&#39;s lik...
void WriteUpdatableCommon(std::ostream &is, bool binary) const
void ApplyPow(Real power)
Take all elements of vector to a power.
Definition: kaldi-vector.h:179
SumBlockComponent sums over blocks of its input: for instance, if you create one with the config "inp...
FixedBiasComponent applies a fixed per-element bias; it&#39;s similar to the AddShift component in the nn...
virtual void ZeroStats()
Components that provide an implementation of StoreStats should also provide an implementation of Zero...
NoOpComponent just duplicates its input.
bool HasUnusedValues() const
Definition: text-utils.cc:510
::MatrixDim Dim() const
Definition: cu-matrix.h:221
bool GetValue(const std::string &key, std::string *value)
Definition: text-utils.cc:427
void StoreStatsInternal(const CuMatrixBase< BaseFloat > &out_value, const CuMatrixBase< BaseFloat > *deriv=NULL)
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34
virtual std::string Info() const
Returns some text-form information about this component, for diagnostics.
virtual void Read(std::istream &is, bool binary)
We implement Read at this level as it just needs the Type().
void Swap(CuVector< Real > *vec)
Definition: cu-vector.cc:1019
MatrixIndexT NumRows() const
Dimensions.
Definition: cu-matrix.h:215
virtual int32 InputDim() const =0
Returns input-dimension of this component.
static Component * NewComponentOfType(const std::string &type)
Returns a new Component of the given type e.g.
static ComponentPrecomputedIndexes * NewComponentPrecomputedIndexesOfType(const std::string &cpi_type)
TimeHeightConvolutionComponent implements 2-dimensional convolution where one of the dimensions of co...
BaseFloat LearningRate() const
Gets the learning rate to be used in gradient descent.
This file contains declarations of components that are not "simple", meaning they care about the inde...
WARNING, this component is deprecated in favor of TimeHeightConvolutionComponent, and will be deleted...
void AddRowSumMat(Real alpha, const CuMatrixBase< Real > &mat, Real beta=1.0)
Sum the rows of the matrix, add to vector.
Definition: cu-vector.cc:1277
MatrixIndexT Dim() const
Dimensions.
Definition: cu-vector.h:69
int32 RandInt(int32 min_val, int32 max_val, struct RandomState *state)
Definition: kaldi-math.cc:95
FixedAffineComponent is an affine transform that is supplied at network initialization time and is no...
This class implements an affine transform using a block diagonal matrix e.g., one whose weight matrix...