nnet-component-itf.h
// nnet3/nnet-component-itf.h

// Copyright 2015  Johns Hopkins University (author: Daniel Povey)
//           2015  Guoguo Chen
//           2015  Xiaohui Zhang

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_NNET3_NNET_COMPONENT_ITF_H_
#define KALDI_NNET3_NNET_COMPONENT_ITF_H_

#include <iostream>
#include "nnet3/nnet-common.h"
#include "nnet3/nnet-parse.h"
#include "base/kaldi-error.h"

namespace kaldi {
namespace nnet3 {
// Enum used to store various binary component properties.
// We give it a name, ComponentProperties, but don't use this
// type for the bitmasks: instead use int32, e.g.
// int32 properties = kSimpleComponent|kBackpropNeedsOutput.
enum ComponentProperties {
  kSimpleComponent = 0x001,  // true if the number of rows of input equals the
                             // number of rows of output and this component
                             // doesn't care about the indexes (i.e. it maps each
                             // row of input to each row of output without regard
                             // to the index values).  Will normally be true.
  kUpdatableComponent = 0x002,  // true if the component has parameters that can
                                // be updated.  Components that return this flag
                                // must be dynamic_castable to type
                                // UpdatableComponent (but components of type
                                // UpdatableComponent do not have to return this
                                // flag, e.g. if this instance is not really
                                // updatable).
  kPropagateInPlace = 0x004,  // true if we can do the propagate operation in-place
                              // (input and output matrices are the same).
                              // Note: if doing backprop, you'd also need to check
                              // that the kBackpropNeedsInput property is not true.
  kPropagateAdds = 0x008,  // true if the Propagate function adds to, rather
                           // than setting, its output, for non-in-place
                           // propagation.  The Component chooses whether to add
                           // or set, and the calling code has to accommodate it.
  kReordersIndexes = 0x010,  // true if the ReorderIndexes function might reorder
                             // the indexes (otherwise we can skip calling it).
                             // Must not be set for simple components.
  kBackpropAdds = 0x020,  // true if the Backprop function adds to, rather than
                          // setting, the "in_deriv" output for non-in-place
                          // backprop.  The Component chooses whether to add or
                          // set, and the calling code has to accommodate it.
  kBackpropNeedsInput = 0x040,  // true if the backprop operation needs access to
                                // the forward-pass input.
  kBackpropNeedsOutput = 0x080,  // true if the backprop operation needs access to
                                 // the forward-pass output (e.g. true for Sigmoid).
  kBackpropInPlace = 0x100,  // true if we can do the backprop operation in-place
                             // (input and output matrices may be the same).
  kStoresStats = 0x200,  // true if the StoreStats operation stores
                         // statistics, e.g. on average node activations and
                         // derivatives of the nonlinearity (as it does for
                         // Tanh, Sigmoid, ReLU and Softmax).
  kInputContiguous = 0x400,  // true if the component requires its input data (and
                             // input derivatives) to have Stride() == NumCols().
  kOutputContiguous = 0x800,  // true if the component requires its output data (and
                              // output derivatives) to have Stride() == NumCols().
  kUsesMemo = 0x1000,  // true if the component returns a void* pointer from its
                       // Propagate() function that needs to be passed into the
                       // corresponding Backprop function.
  kRandomComponent = 0x2000  // true if the component has some kind of
                             // randomness, like DropoutComponent (these should
                             // inherit from class RandomComponent).
};
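
// Illustrative sketch (not part of the interface): a component advertises its
// properties by OR-ing flags together in Properties(), and calling code tests
// them with bitwise AND.  The particular flag combination below is an
// assumption chosen for illustration:
//
//   virtual int32 Properties() const {
//     return kSimpleComponent|kBackpropNeedsOutput|kPropagateInPlace;
//   }
//   ...
//   if (component.Properties() & kBackpropNeedsOutput) {
//     // keep the forward-pass output around for the backward pass.
//   }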


// This is a base class for a helper class of class Component, which is used to
// store any precomputed indexes it needs for its forward and backward
// computations.  For components which are not "simple" components (i.e. the
// kSimpleComponent property is false), and which may therefore "care" about
// which index the input and output matrix's rows represent (i.e. about
// which "struct Index" each row corresponds to), their PrecomputeIndexes()
// function will be called prior to Propagate() and Backprop(), to create an
// object which must be a child class of class ComponentPrecomputedIndexes,
// where they can store any indexes that they need.
class ComponentPrecomputedIndexes {
 public:
  virtual ComponentPrecomputedIndexes *Copy() const = 0;
  virtual void Write(std::ostream &os, bool binary) const = 0;
  virtual void Read(std::istream &is, bool binary) = 0;
  virtual std::string Type() const = 0;
  static ComponentPrecomputedIndexes* ReadNew(std::istream &is, bool binary);
  // cpi stands for component_precomputed_indexes
  static ComponentPrecomputedIndexes* NewComponentPrecomputedIndexesOfType(
      const std::string &cpi_type);
  virtual ~ComponentPrecomputedIndexes() { }
};
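
// Minimal sketch of a derived precomputed-indexes class, under assumed names
// (this particular class does not exist in the library; real examples live in
// the various non-simple components):
//
//   class MyComponentPrecomputedIndexes: public ComponentPrecomputedIndexes {
//    public:
//     virtual ComponentPrecomputedIndexes *Copy() const {
//       return new MyComponentPrecomputedIndexes(*this);
//     }
//     virtual void Write(std::ostream &os, bool binary) const;
//     virtual void Read(std::istream &is, bool binary);
//     virtual std::string Type() const {
//       return "MyComponentPrecomputedIndexes";
//     }
//    private:
//     CuArray<int32> row_map_;  // e.g. a precomputed row mapping.
//   };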


class IndexSet;  // Forward declaration; declared in nnet-computation-graph.h.

/// Abstract base-class for neural-net components.
class Component {
 public:
  // Propagates the input through the component, computing the output.  For
  // non-simple components, 'indexes' is the object created by
  // PrecomputeIndexes(); for simple components it will be NULL.  If the
  // component sets the kUsesMemo property, the returned void* is a "memo"
  // that must be passed to the corresponding Backprop() call; otherwise
  // NULL is returned.
  virtual void* Propagate(const ComponentPrecomputedIndexes *indexes,
                          const CuMatrixBase<BaseFloat> &in,
                          CuMatrixBase<BaseFloat> *out) const = 0;

  // Performs the backward pass: computes the derivative w.r.t. the input
  // ('in_deriv') from the derivative w.r.t. the output ('out_deriv'), and, if
  // 'to_update' is non-NULL, updates its parameters.  'in_value' and
  // 'out_value' only need to contain data if the component returns the
  // kBackpropNeedsInput or kBackpropNeedsOutput properties, respectively.
  virtual void Backprop(const std::string &debug_info,
                        const ComponentPrecomputedIndexes *indexes,
                        const CuMatrixBase<BaseFloat> &in_value,
                        const CuMatrixBase<BaseFloat> &out_value,
                        const CuMatrixBase<BaseFloat> &out_deriv,
                        void *memo,
                        Component *to_update,  // may be NULL; may be identical
                                               // to "this" or different.
                        CuMatrixBase<BaseFloat> *in_deriv) const = 0;

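  // Illustrative calling pattern (a sketch, not code from the library; the
  // variable names are assumptions).  Note how the memo returned by
  // Propagate() is handed to Backprop() and then freed:
  //
  //   void *memo = component.Propagate(indexes, in, &out);
  //   // ... compute out_deriv from the objective function ...
  //   component.Backprop("layer1", indexes, in, out, out_deriv,
  //                      memo, to_update, &in_deriv);
  //   component.DeleteMemo(memo);
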
  // This function may store stats on average activation values, and for some
  // component types, on the average value of the derivative of the
  // nonlinearity.  'memo' is the pointer, if any, that was returned from the
  // corresponding Propagate() call.
  virtual void StoreStats(const CuMatrixBase<BaseFloat> &in_value,
                          const CuMatrixBase<BaseFloat> &out_value,
                          void *memo) { }

  // Components that provide an implementation of StoreStats should also
  // provide an implementation of ZeroStats(), to set those stats to zero.
  virtual void ZeroStats() { }



  // This function only does something interesting for non-simple Components.
  // For a given index at the output of the component, it tells us what indexes
  // are required at the input.
  virtual void GetInputIndexes(const MiscComputationInfo &misc_info,
                               const Index &output_index,
                               std::vector<Index> *desired_indexes) const;

  // This function only does something interesting for non-simple Components.
  // It returns true if the output index is computable from the supplied set of
  // input indexes (and if so, and 'used_inputs' is non-NULL, it outputs the
  // input indexes that would actually be used).
  virtual bool IsComputable(const MiscComputationInfo &misc_info,
                            const Index &output_index,
                            const IndexSet &input_index_set,
                            std::vector<Index> *used_inputs) const;

  // This function only does something interesting for non-simple Components;
  // it gives them a chance to reorder the input and output indexes (see the
  // kReordersIndexes property).
  virtual void ReorderIndexes(std::vector<Index> *input_indexes,
                              std::vector<Index> *output_indexes) const {}
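
  // Illustrative sketch (an assumed component, not one from the library): a
  // non-simple component that splices together frames t-1 and t+1 might
  // implement GetInputIndexes() roughly like this:
  //
  //   void MySpliceComponent::GetInputIndexes(
  //       const MiscComputationInfo &misc_info, const Index &output_index,
  //       std::vector<Index> *desired_indexes) const {
  //     desired_indexes->clear();
  //     Index index(output_index);
  //     index.t = output_index.t - 1;
  //     desired_indexes->push_back(index);
  //     index.t = output_index.t + 1;
  //     desired_indexes->push_back(index);
  //   }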



  // This function must return NULL for simple Components.  For non-simple
  // Components it may return a child class of ComponentPrecomputedIndexes
  // containing any indexes the component needs to precompute for its
  // Propagate() and Backprop() operations.
  virtual ComponentPrecomputedIndexes* PrecomputeIndexes(
      const MiscComputationInfo &misc_info,
      const std::vector<Index> &input_indexes,
      const std::vector<Index> &output_indexes,
      bool need_backprop) const { return NULL; }


  // Returns a string such as "SigmoidComponent", describing the type of the
  // object.
  virtual std::string Type() const = 0;

  // Initializes the component from a ConfigLine object (a parsed config-file
  // line).
  virtual void InitFromConfig(ConfigLine *cfl) = 0;

  // Returns the input-dimension of this component.
  virtual int32 InputDim() const = 0;

  // Returns the output-dimension of this component.
  virtual int32 OutputDim() const = 0;

  // Returns a bitmask of the component's properties; see enum
  // ComponentProperties above.
  virtual int32 Properties() const = 0;

  // Reads component from stream (works out its type).  Dies on error.
  static Component* ReadNew(std::istream &is, bool binary);

  // Copies component (deep copy).
  virtual Component* Copy() const = 0;

  // Returns a new Component of the given type, e.g. "SoftmaxComponent", or
  // NULL if no such component type exists.
  static Component *NewComponentOfType(const std::string &type);

  // Read function (used after we know the type of the Component).
  virtual void Read(std::istream &is, bool binary) = 0;

  // Writes component to stream.
  virtual void Write(std::ostream &os, bool binary) const = 0;

  // Returns some text-form information about this component, for diagnostics;
  // it starts with the component's type.
  virtual std::string Info() const;

  // This virtual function, when called on an UpdatableComponent, scales the
  // parameters by "scale".
  virtual void Scale(BaseFloat scale) { }

  // This virtual function, when called on an UpdatableComponent, adds the
  // parameters of 'other', times "alpha".
  virtual void Add(BaseFloat alpha, const Component &other) { }

  // This virtual function only needs to be overwritten by Components that
  // return a non-NULL memo from their Propagate() function.
  virtual void DeleteMemo(void *memo) const { KALDI_ASSERT(memo == NULL); }

  // This virtual function relates to memory management and avoiding
  // fragmentation.
  virtual void ConsolidateMemory() { }

  Component() { }

  virtual ~Component() { }

 private:
  KALDI_DISALLOW_COPY_AND_ASSIGN(Component);
};


class RandomComponent: public Component {
 public:
  // This function is required in testing code and in other places where we
  // need consistency in the random number generation (e.g. when optimizing
  // validation-set performance); but check where else we call srand().  You'll
  // need to call srand() prior to making this call.
  void ResetGenerator() { random_generator_.SeedGpu(); }

  // Call this with 'true' to set 'test mode', where the behavior is different
  // from normal mode.
  void SetTestMode(bool test_mode) { test_mode_ = test_mode; }

  RandomComponent(): test_mode_(false) { }

  RandomComponent(const RandomComponent &other):
      test_mode_(other.test_mode_) {}
 protected:
  CuRand<BaseFloat> random_generator_;

  // This is true if we want a different behavior for inference than for
  // training.
  bool test_mode_;
};

// Class UpdatableComponent is a Component which has trainable parameters; it
// extends the interface of Component.
class UpdatableComponent: public Component {
 public:
  UpdatableComponent(const UpdatableComponent &other);

  // If these defaults are changed, the defaults in
  // InitLearningRatesFromConfig() should be changed too.
  UpdatableComponent(): learning_rate_(0.001), learning_rate_factor_(1.0),
                        l2_regularize_(0.0), is_gradient_(false),
                        max_change_(0.0) { }

  virtual ~UpdatableComponent() { }

  // Computes the dot-product between the parameters of two instances of a
  // Component (can be used for computing the parameter-norm of a component).
  virtual BaseFloat DotProduct(const UpdatableComponent &other) const = 0;

  // Adds noise with the given standard deviation to the parameters of the
  // component (used in testing).
  virtual void PerturbParams(BaseFloat stddev) = 0;

  // Sets the learning rate of gradient descent; the value gets multiplied by
  // learning_rate_factor_.
  virtual void SetUnderlyingLearningRate(BaseFloat lrate) {
    learning_rate_ = lrate * learning_rate_factor_;
  }
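
  // Worked example (an illustration, not library code): if
  // learning_rate_factor_ is 0.5, then SetUnderlyingLearningRate(0.002)
  // results in learning_rate_ == 0.001, whereas SetActualLearningRate(0.002)
  // below sets learning_rate_ to 0.002 directly.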

  // Sets the learning rate directly, bypassing learning_rate_factor_.
  virtual void SetActualLearningRate(BaseFloat lrate) { learning_rate_ = lrate; }

  // Sets is_gradient_ to true and sets learning_rate_ to 1, ignoring
  // learning_rate_factor_.
  virtual void SetAsGradient() { learning_rate_ = 1.0; is_gradient_ = true; }

  virtual BaseFloat LearningRateFactor() { return learning_rate_factor_; }

  // Sets the learning-rate factor to lrate_factor.
  virtual void SetLearningRateFactor(BaseFloat lrate_factor) {
    learning_rate_factor_ = lrate_factor;
  }

  // Copies the learning-rate, learning-rate-factor, l2-regularize, is-gradient
  // and max-change values from 'other'.
  void SetUpdatableConfigs(const UpdatableComponent &other);

  // Freezes/unfreezes natural-gradient updates, if applicable (to be
  // overridden by components that use natural gradient).
  virtual void FreezeNaturalGradient(bool freeze) { }

  // Gets the learning rate to be used in gradient descent.
  BaseFloat LearningRate() const { return learning_rate_; }

  // Returns the per-component max-change value, which is interpreted as the
  // maximum change (in l2 norm) in parameters allowed per minibatch.
  BaseFloat MaxChange() const { return max_change_; }

  void SetMaxChange(BaseFloat max_change) { max_change_ = max_change; }

  // Returns the l2 regularization constant, which may be set in any updatable
  // component (usually from the config file).
  BaseFloat L2Regularization() const { return l2_regularize_; }

  void SetL2Regularization(BaseFloat a) { l2_regularize_ = a; }

  virtual std::string Info() const;

  // The following virtual function returns the total dimension of the
  // parameters in this class.
  virtual int32 NumParameters() const { KALDI_ASSERT(0); return 0; }

  // Turns the parameters into vector form.
  virtual void Vectorize(VectorBase<BaseFloat> *params) const { KALDI_ASSERT(0); }

  // Converts the parameters from vector form.
  virtual void UnVectorize(const VectorBase<BaseFloat> &params) {
    KALDI_ASSERT(0);
  }

 protected:
  // To be called from child classes, this extracts any learning-rate
  // information from the config line and sets it appropriately.
  void InitLearningRatesFromConfig(ConfigLine *cfl);

  // To be used in child-class Read() functions, this function reads the
  // opening tag <ThisComponentType> and the learning-rate factor and the
  // learning rate.
  //
  // Its return value does not always need to be inspected by calling code;
  // if there was a token that it read but could not process, it returns it,
  // otherwise it returns "".
  std::string ReadUpdatableCommon(std::istream &is, bool binary);

  // To be used in child-class Write() functions, this writes the opening
  // <ThisComponentType> tag and the learning-rate factor (if not 1.0) and the
  // learning rate.
  void WriteUpdatableCommon(std::ostream &os, bool binary) const;

  BaseFloat learning_rate_;  // learning rate (typically 0.0 .. 0.01)
  BaseFloat learning_rate_factor_;  // learning-rate factor (normally 1.0, but
                                    // can be set to another value so that when
                                    // you call SetUnderlyingLearningRate(),
                                    // that value will be scaled by this
                                    // factor).
  BaseFloat l2_regularize_;  // L2 regularization constant.
  bool is_gradient_;  // true if this component is to be treated as a gradient
                      // rather than as parameters.
  BaseFloat max_change_;  // configuration value for imposing the per-minibatch
                          // max-change.

 private:
  const UpdatableComponent &operator = (const UpdatableComponent &other);  // Disallow.
};
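
// Sketch of how a child class might use ReadUpdatableCommon() and
// WriteUpdatableCommon() in its I/O functions (illustrative; the component
// name and member are assumptions):
//
//   void MyComponent::Read(std::istream &is, bool binary) {
//     ReadUpdatableCommon(is, binary);  // reads opening tag, learning rate.
//     ExpectToken(is, binary, "<Params>");
//     params_.Read(is, binary);
//     ExpectToken(is, binary, "</MyComponent>");
//   }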


/* NonlinearComponent is a base-class for things like sigmoid, softmax and
   ReLU: nonlinearities that don't change the dimension.  This base-class
   takes care of storing statistics on the average activations and derivatives
   encountered during training, and of model initialization and I/O.

   Supported parameters on the config line:

   dim          Dimension of the input and output of the component.
                (Caution: for NormalizeComponent, there is a member
                "add-log-stddev" which, if true, will increase the output
                dim by one, so it will be "dim" plus one.)

   self-repair-scale=0.0   A scale for the self-repair mechanism (which nudges
          the activation values towards the 'good' regions when a particular
          dimension of the activations seems to be oversaturated or otherwise
          unbalanced).  This is typically set from the script level to values
          like 1.0e-04 to 1.0e-05.

   self-repair-lower-threshold=-1000   A lower threshold for the self-repair
          mechanism; it will be interpreted in a component-specific way,
          typically a lower limit on the average derivative or activation
          below which the self-repair mechanism is activated.  -1000 is a
          special value which will cause a component-specific default to be
          used.

   self-repair-upper-threshold=-1000   An upper threshold for the self-repair
          mechanism; it will be interpreted in a component-specific way,
          typically an upper limit on the average derivative or activation
          above which the self-repair mechanism is activated.  -1000 is a
          special value which will cause a component-specific default to be
          used.

   block-dim    Defaults to dim, but may be any divisor of dim.  It affects
          the self-repair, which will be done while treating the input/output
          as repeating blocks of size 'block-dim' (e.g. blocks of filters).
          It allows us to do self-repair at the filter level in CNNs.
          Currently this only makes a difference for RectifiedLinearComponent.
*/
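// For example, a config line for a component deriving from NonlinearComponent
// might look like the following (an illustrative sketch; the exact component
// type and values are assumptions):
//
//   component name=relu1 type=RectifiedLinearComponent dim=1024 self-repair-scale=1.0e-05 block-dim=256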
class NonlinearComponent: public Component {
 public:
  NonlinearComponent();
  explicit NonlinearComponent(const NonlinearComponent &other);

  virtual int32 InputDim() const { return dim_; }
  virtual int32 OutputDim() const { return dim_; }

  // We implement InitFromConfig at this level and this version is sufficient
  // for most of the child classes.  Note: it's overridden by class
  // NormalizeComponent.
  virtual void InitFromConfig(ConfigLine *cfl);

  virtual void Read(std::istream &is, bool binary);

  virtual void ZeroStats();

  virtual std::string Info() const;

  virtual void Write(std::ostream &os, bool binary) const;

  virtual void Scale(BaseFloat scale);
  virtual void Add(BaseFloat alpha, const Component &other);

  virtual void ConsolidateMemory();

  // The following functions are unique to NonlinearComponent.
  // They mostly relate to diagnostics.
  const CuVector<double> &ValueSum() const { return value_sum_; }
  const CuVector<double> &DerivSum() const { return deriv_sum_; }

  double Count() const { return count_; }

 protected:
  enum { kUnsetThreshold = -1000 };

  friend class SigmoidComponent;
  friend class TanhComponent;
  friend class SoftmaxComponent;
  friend class LogSoftmaxComponent;
  friend class RectifiedLinearComponent;

  // This function updates the stats "value_sum_", "deriv_sum_", and
  // count_.  (If deriv == NULL, it won't update "deriv_sum_".)
  // It will be called from the Backprop function of child classes.
  void StoreStatsInternal(const CuMatrixBase<BaseFloat> &out_value,
                          const CuMatrixBase<BaseFloat> *deriv = NULL);

  // This function may be called from child-class members during backprop.  It
  // stores the 'oderiv_sumsq_' stats.
  void StoreBackpropStats(const CuMatrixBase<BaseFloat> &out_deriv);


  const NonlinearComponent &operator = (const NonlinearComponent &other);  // Disallow.

  // dim_ is the input dimension (and almost always the output dimension) of
  // the component.
  int32 dim_;
  // block_dim_ will normally be the same as dim_, but it may be any nonzero
  // divisor of dim_; if so, each vector is treated as a number of blocks
  // appended together, and this affects the stats accumulation and
  // self-repair.  Currently this is only supported for
  // RectifiedLinearComponent.
  int32 block_dim_;
  CuVector<double> value_sum_;  // stats at the output.
  CuVector<double> deriv_sum_;  // stats of the derivative of the nonlinearity
                                // (only applicable to element-by-element
                                // nonlinearities, not Softmax).
  // Count corresponding to the stats in 'value_sum_' and 'deriv_sum_'.
  double count_;

  CuVector<double> oderiv_sumsq_;  // Sum-square of the derivative of the
                                   // objective function that we're propagating
                                   // back.  Accumulated during the backprop;
                                   // used for diagnostics.
  // Count corresponding to the stats in 'oderiv_sumsq_'.
  double oderiv_count_;

  // some stats for self-repairing nonlinearities.
  double num_dims_self_repaired_;
  double num_dims_processed_;

  // some configuration values relating to self-repairing nonlinearities.
  BaseFloat self_repair_lower_threshold_;
  BaseFloat self_repair_upper_threshold_;
  BaseFloat self_repair_scale_;
};

} // namespace nnet3
} // namespace kaldi


#endif