#include <nnet-multibasis-component.h>


Public Member Functions
|  | MultiBasisComponent (int32 dim_in, int32 dim_out) |  |
|  | ~MultiBasisComponent () |  |
| Component * | Copy () const | Copy component (deep copy). |
| ComponentType | GetType () const | Get type identification of the component. |
| void | InitData (std::istream &is) | Initialize the content of the component by the 'line' from the prototype. |
| void | ReadData (std::istream &is, bool binary) | Reads the component content. |
| void | WriteData (std::ostream &os, bool binary) const | Writes the component content. |
| Nnet & | GetBasis (int32 id) |  |
| const Nnet & | GetBasis (int32 id) const |  |
| int32 | NumParams () const | Number of trainable parameters. |
| void | GetGradient (VectorBase< BaseFloat > *gradient) const | Get gradient reshaped as a vector. |
| void | GetParams (VectorBase< BaseFloat > *params) const | Get the trainable parameters reshaped as a vector. |
| void | SetParams (const VectorBase< BaseFloat > &params) | Set the trainable parameters from a flat vector. |
| std::string | Info () const | Print some additional info (after <ComponentName> and the dims). |
| std::string | InfoGradient () const | Print some additional info about gradient (after <...> and dims). |
| std::string | InfoPropagate () const |  |
| std::string | InfoBackPropagate () const |  |
| void | PropagateFnc (const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) | Abstract interface for propagation/backpropagation. |
| void | BackpropagateFnc (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrixBase< BaseFloat > *in_diff) | Backward pass transformation (to be implemented by the derived class). |
| void | Update (const CuMatrixBase< BaseFloat > &input, const CuMatrixBase< BaseFloat > &diff) | Compute gradient and update parameters. |
| void | SetTrainOptions (const NnetTrainOptions &opts) | Overriding the default, which was UpdatableComponent::SetTrainOptions(...). |
| void | SetLearnRateCoef (BaseFloat val) | Overriding the default, which was UpdatableComponent::SetLearnRateCoef(...). |
| void | SetBiasLearnRateCoef (BaseFloat val) | Overriding the default, which was UpdatableComponent::SetBiasLearnRateCoef(...). |
Public Member Functions inherited from UpdatableComponent
|  | UpdatableComponent (int32 input_dim, int32 output_dim) |  |
| virtual | ~UpdatableComponent () |  |
| bool | IsUpdatable () const | Check if the component contains trainable parameters. |
| const NnetTrainOptions & | GetTrainOptions () const | Get the training options from the component. |
Public Member Functions inherited from Component
|  | Component (int32 input_dim, int32 output_dim) | Generic interface of a component. |
| virtual | ~Component () |  |
| virtual bool | IsMultistream () const | Check if component has 'Recurrent' interface (trainable and recurrent). |
| int32 | InputDim () const | Get the dimension of the input. |
| int32 | OutputDim () const | Get the dimension of the output. |
| void | Propagate (const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out) | Perform forward-pass propagation 'in' -> 'out'. |
| void | Backpropagate (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff) | Perform backward-pass propagation 'out_diff' -> 'in_diff'. |
| void | Write (std::ostream &os, bool binary) const | Write the component to a stream. |
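The two signatures above differ in one detail: Propagate() writes into a resizable CuMatrix< BaseFloat >, while PropagateFnc() receives a CuMatrixBase< BaseFloat >, which cannot be resized. A plausible reading is that the base class checks the input width and sizes the output before delegating to the derived class. The sketch below illustrates that split under this assumption only; Mat, SketchComponent and Doubler are hypothetical stand-ins, not Kaldi types.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Minimal row-major matrix; a hypothetical stand-in for CuMatrix.
struct Mat {
  int rows = 0, cols = 0;
  std::vector<float> data;
  void Resize(int r, int c) {
    rows = r;
    cols = c;
    data.assign(static_cast<std::size_t>(r) * static_cast<std::size_t>(c), 0.0f);
  }
  float& operator()(int r, int c) { return data[static_cast<std::size_t>(r * cols + c)]; }
  float operator()(int r, int c) const { return data[static_cast<std::size_t>(r * cols + c)]; }
};

// Assumed split: the public, non-virtual Propagate() checks and sizes,
// the virtual PropagateFnc() only computes into an already-sized output.
class SketchComponent {
 public:
  SketchComponent(int input_dim, int output_dim)
      : input_dim_(input_dim), output_dim_(output_dim) {}
  virtual ~SketchComponent() {}

  void Propagate(const Mat& in, Mat* out) {
    assert(in.cols == input_dim_);      // input width must match
    out->Resize(in.rows, output_dim_);  // zeroed output of the right shape
    PropagateFnc(in, out);              // derived class fills it in
  }

 protected:
  virtual void PropagateFnc(const Mat& in, Mat* out) = 0;
  int input_dim_;
  int output_dim_;
};

// Toy derived component: output = 2 * input.
class Doubler : public SketchComponent {
 public:
  explicit Doubler(int dim) : SketchComponent(dim, dim) {}

 protected:
  void PropagateFnc(const Mat& in, Mat* out) override {
    for (int r = 0; r < in.rows; ++r)
      for (int c = 0; c < in.cols; ++c) (*out)(r, c) = 2.0f * in(r, c);
  }
};
```

With this split, a derived class never has to allocate or validate its output; it only implements the arithmetic.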
Private Attributes
| std::vector< Nnet > | nnet_basis_ | The vector of 'basis' networks (output of basis is combined according to the posterior_ from the selector_). |
| std::vector< CuMatrix< BaseFloat > > | basis_out_ |  |
| Nnet | selector_ | Selector network. |
| BaseFloat | selector_lr_coef_ |  |
| CuMatrix< BaseFloat > | posterior_ | The output of 'selector_'. |
| Vector< BaseFloat > | posterior_sum_ |  |
| BaseFloat | threshold_ | Threshold, applied to posterior_sum_, disables the unused basis. |
Additional Inherited Members
Public Types inherited from Component
| enum | ComponentType { kUnknown = 0x0, kUpdatableComponent = 0x0100, kAffineTransform, kLinearTransform, kConvolutionalComponent, kLstmProjected, kBlstmProjected, kRecurrentComponent, kActivationFunction = 0x0200, kSoftmax, kHiddenSoftmax, kBlockSoftmax, kSigmoid, kTanh, kParametricRelu, kDropout, kLengthNormComponent, kTranform = 0x0400, kRbm, kSplice, kCopy, kTranspose, kBlockLinearity, kAddShift, kRescale, kKlHmm = 0x0800, kSentenceAveragingComponent, kSimpleSentenceAveragingComponent, kAveragePoolingComponent, kMaxPoolingComponent, kFramePoolingComponent, kParallelComponent, kMultiBasisComponent } | Component type identification mechanism. |
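C++ numbers an enumerator without an initializer as the previous value plus one, so the members after kKlHmm = 0x0800 count upward from 0x0801 and kMultiBasisComponent works out to 0x0807. Its code therefore sits in the range opened by kKlHmm, not in the kUpdatableComponent = 0x0100 range, even though this class derives from UpdatableComponent; IsUpdatable() is the direct check for trainable parameters. The snippet copies only that tail of the enum to make the arithmetic checkable; ComponentTypeTail and GroupOf are hypothetical names.

```cpp
#include <cassert>

// Only the tail of ComponentType, copied from the listing above so the
// implicit values can be checked; ComponentTypeTail is a hypothetical name.
enum ComponentTypeTail {
  kKlHmm = 0x0800,                    // explicit value starts the range
  kSentenceAveragingComponent,        // 0x0801
  kSimpleSentenceAveragingComponent,  // 0x0802
  kAveragePoolingComponent,           // 0x0803
  kMaxPoolingComponent,               // 0x0804
  kFramePoolingComponent,             // 0x0805
  kParallelComponent,                 // 0x0806
  kMultiBasisComponent                // 0x0807
};

// Hypothetical helper: keeps the high byte, i.e. the explicitly numbered
// range (0x0100, 0x0200, 0x0400 or 0x0800) that a type code falls into.
constexpr int GroupOf(int type_code) { return type_code & 0xFF00; }

static_assert(kMultiBasisComponent == 0x0807, "implicit enumerator numbering");
```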
Static Public Member Functions inherited from Component
| static const char * | TypeToMarker (ComponentType t) | Converts component type to marker. |
| static ComponentType | MarkerToType (const std::string &s) | Converts marker to component type (case insensitive). |
| static Component * | Init (const std::string &conf_line) | Initialize component from a line in config file. |
| static Component * | Read (std::istream &is, bool binary) | Read the component from a stream (static method). |
Static Public Attributes inherited from Component
| static const struct key_value | kMarkerMap [] | The table with pairs of Component types and markers (defined in nnet-component.cc). |
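TypeToMarker() and MarkerToType() are described as conversions backed by the kMarkerMap table of (type, marker) pairs, with the marker-to-type direction case insensitive. A minimal sketch of such a lookup follows. It assumes a two-field entry and the marker spelling "<MultiBasisComponent>"; the enum, struct and function names are hypothetical, and these versions signal a miss by returning kSketchUnknown or nullptr rather than by any library-specific error.

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <string>

// Hypothetical stand-ins; the real table is Component::kMarkerMap.
enum SketchType { kSketchUnknown = 0x0, kSketchMultiBasis = 0x0807 };

struct SketchKeyValue {
  SketchType key;
  const char* value;  // marker string (spelling assumed)
};

const SketchKeyValue kSketchMarkerMap[] = {
    {kSketchMultiBasis, "<MultiBasisComponent>"},
};

std::string ToLower(std::string s) {
  std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) {
    return static_cast<char>(std::tolower(c));
  });
  return s;
}

// Case-insensitive marker -> type; unknown markers map to kSketchUnknown.
SketchType SketchMarkerToType(const std::string& marker) {
  const std::string wanted = ToLower(marker);
  for (const SketchKeyValue& kv : kSketchMarkerMap) {
    if (ToLower(kv.value) == wanted) return kv.key;
  }
  return kSketchUnknown;
}

// Type -> marker; unknown types map to nullptr.
const char* SketchTypeToMarker(SketchType t) {
  for (const SketchKeyValue& kv : kSketchMarkerMap) {
    if (kv.key == t) return kv.value;
  }
  return nullptr;
}
```

Lower-casing both sides before comparing is one simple way to get case insensitivity; the table itself keeps the canonical spelling for output.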
Protected Attributes inherited from UpdatableComponent
| NnetTrainOptions | opts_ | Option-class with training hyper-parameters. |
| BaseFloat | learn_rate_coef_ | Scalar applied to learning rate for weight matrices (to be used in ::Update method). |
| BaseFloat | bias_learn_rate_coef_ | Scalar applied to learning rate for bias (to be used in ::Update method). |
Protected Attributes inherited from Component
| int32 | input_dim_ | Data members: dimension of the input of the Component. |
| int32 | output_dim_ | Dimension of the output of the Component. |
Definition at line 34 of file nnet-multibasis-component.h.
      
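Constructor & Destructor Documentation
MultiBasisComponent::MultiBasisComponent (int32 dim_in, int32 dim_out) [inline]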
Definition at line 36 of file nnet-multibasis-component.h.
Referenced by MultiBasisComponent::Copy().
      
MultiBasisComponent::~MultiBasisComponent () [inline]
Definition at line 42 of file nnet-multibasis-component.h.
      
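Member Function Documentation
void MultiBasisComponent::BackpropagateFnc (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrixBase< BaseFloat > *in_diff) [inline, virtual]
Backward pass transformation (to be implemented by the derived class).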
Implements Component.
Definition at line 338 of file nnet-multibasis-component.h.
References CuMatrixBase< Real >::AddDiagVecMat(), CuMatrixBase< Real >::AddMat(), Nnet::Backpropagate(), MultiBasisComponent::basis_out_, CuMatrixBase< Real >::ColRange(), CuMatrixBase< Real >::CopyFromMat(), Nnet::InputDim(), kaldi::kNoTrans, kaldi::kTrans, MultiBasisComponent::nnet_basis_, CuMatrixBase< Real >::NumRows(), Component::OutputDim(), MultiBasisComponent::posterior_, MultiBasisComponent::posterior_sum_, CuMatrixBase< Real >::Row(), CuMatrixBase< Real >::Scale(), MultiBasisComponent::selector_, MultiBasisComponent::selector_lr_coef_, MultiBasisComponent::threshold_, and CuMatrix< Real >::Transpose().
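If the forward pass mixes the basis outputs as out = Σ_i diag(p_i)·y_i, where p_i is column i of posterior_ and y_i is basis_out_[i] (an assumption consistent with the nnet_basis_ description and the AddDiagVecMat() reference), the chain rule yields two error signals: basis i receives out_diff with each row scaled by that frame's p_i, and the selector receives, for every frame and basis, the dot product of that row of out_diff with the same row of y_i. The sketch computes only these two quantities with plain std::vector. It leaves out the Nnet::Backpropagate() calls and any scaling by selector_lr_coef_, and its `active` flags stand in for the posterior_sum_ versus threshold_ test; CombineBackward is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>;  // [frame][dim]

// Backward pass for the assumed mixing out[t] = sum_i posterior[t][i] * basis_out[i][t].
//   basis_out_diff[i] : error at the output of basis i (rows of out_diff scaled by posterior).
//   posterior_diff    : error at the selector output, one column per basis.
// Inactive bases get an empty matrix and a zero column.
void CombineBackward(const Matrix& out_diff, const Matrix& posterior,
                     const std::vector<Matrix>& basis_out,
                     const std::vector<bool>& active,
                     std::vector<Matrix>* basis_out_diff,
                     Matrix* posterior_diff) {
  const std::size_t num_frames = out_diff.size();
  const std::size_t num_basis = basis_out.size();
  assert(active.size() == num_basis && posterior.size() == num_frames);
  basis_out_diff->assign(num_basis, Matrix());
  posterior_diff->assign(num_frames, std::vector<float>(num_basis, 0.0f));
  for (std::size_t i = 0; i < num_basis; ++i) {
    if (!active[i]) continue;
    Matrix& g = (*basis_out_diff)[i];
    g = out_diff;
    for (std::size_t t = 0; t < num_frames; ++t) {
      float dot = 0.0f;
      for (std::size_t d = 0; d < out_diff[t].size(); ++d) {
        g[t][d] *= posterior[t][i];                  // d out / d basis_out = posterior
        dot += out_diff[t][d] * basis_out[i][t][d];  // d out / d posterior = basis_out
      }
      (*posterior_diff)[t][i] = dot;
    }
  }
}
```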
      
Component * MultiBasisComponent::Copy () const [inline, virtual]
Copy component (deep copy).
Implements Component.
Definition at line 45 of file nnet-multibasis-component.h.
References MultiBasisComponent::MultiBasisComponent().
Nnet & MultiBasisComponent::GetBasis (int32 id)
Definition at line 196 of file nnet-multibasis-component.h.
References MultiBasisComponent::nnet_basis_.
const Nnet & MultiBasisComponent::GetBasis (int32 id) const
Definition at line 197 of file nnet-multibasis-component.h.
References MultiBasisComponent::nnet_basis_.
      
void MultiBasisComponent::GetGradient (VectorBase< BaseFloat > *gradient) const [inline, virtual]
Get gradient reshaped as a vector.
Implements UpdatableComponent.
Definition at line 207 of file nnet-multibasis-component.h.
References KALDI_ERR.
      
void MultiBasisComponent::GetParams (VectorBase< BaseFloat > *params) const [inline, virtual]
Get the trainable parameters reshaped as a vector.
Implements UpdatableComponent.
Definition at line 211 of file nnet-multibasis-component.h.
References VectorBase< Real >::Dim(), Nnet::GetParams(), KALDI_ASSERT, MultiBasisComponent::nnet_basis_, MultiBasisComponent::NumParams(), VectorBase< Real >::Range(), and MultiBasisComponent::selector_.
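GetParams() fills a caller-sized vector, and its references (VectorBase::Dim(), VectorBase::Range(), NumParams(), and Nnet::GetParams() on the selector and the bases) suggest that each network's parameters are copied into consecutive sub-ranges. The following is a minimal sketch of that packing with std::vector. Which network comes first is not stated on this page, so the block order is left to the caller; PackParams is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Copies each parameter block into consecutive sub-ranges of a caller-sized
// buffer. Blocks stand in for per-network Nnet::GetParams() results; their
// order is whatever the caller passes (the real order is not documented here).
void PackParams(const std::vector<std::vector<float>>& blocks,
                std::vector<float>* params) {
  std::size_t total = 0;
  for (const auto& b : blocks) total += b.size();
  assert(params->size() == total);  // buffer length must equal the total count
  std::size_t offset = 0;
  for (const auto& b : blocks) {
    // Write block b into [offset, offset + b.size()), like Range(offset, size).
    std::copy(b.begin(), b.end(),
              params->begin() + static_cast<std::ptrdiff_t>(offset));
    offset += b.size();
  }
}
```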
      
ComponentType MultiBasisComponent::GetType () const [inline, virtual]
Get type identification of the component.
Implements Component.
Definition at line 46 of file nnet-multibasis-component.h.
References Component::kMultiBasisComponent.
      
std::string MultiBasisComponent::Info () const [inline, virtual]
Print some additional info (after <ComponentName> and the dims).
Reimplemented from Component.
Definition at line 240 of file nnet-multibasis-component.h.
References Nnet::Info(), MultiBasisComponent::nnet_basis_, and MultiBasisComponent::selector_.
      
std::string MultiBasisComponent::InfoBackPropagate () const [inline]
Definition at line 283 of file nnet-multibasis-component.h.
References Nnet::InfoBackPropagate(), MultiBasisComponent::nnet_basis_, MultiBasisComponent::posterior_sum_, MultiBasisComponent::selector_, and MultiBasisComponent::threshold_.
      
std::string MultiBasisComponent::InfoGradient () const [inline, virtual]
Print some additional info about gradient (after <...> and dims).
Reimplemented from Component.
Definition at line 253 of file nnet-multibasis-component.h.
References Nnet::InfoGradient(), MultiBasisComponent::nnet_basis_, MultiBasisComponent::posterior_sum_, MultiBasisComponent::selector_, and MultiBasisComponent::threshold_.
      
std::string MultiBasisComponent::InfoPropagate () const [inline]
Definition at line 268 of file nnet-multibasis-component.h.
References Nnet::InfoPropagate(), MultiBasisComponent::nnet_basis_, MultiBasisComponent::posterior_sum_, MultiBasisComponent::selector_, and MultiBasisComponent::threshold_.
      
void MultiBasisComponent::InitData (std::istream &is) [inline, virtual]
Initialize the content of the component by the 'line' from the prototype.
Implements UpdatableComponent.
Definition at line 48 of file nnet-multibasis-component.h.
References Nnet::Init(), Nnet::InputDim(), Component::InputDim(), KALDI_ASSERT, KALDI_ERR, KALDI_LOG, MultiBasisComponent::nnet_basis_, Nnet::OutputDim(), Component::OutputDim(), Nnet::Read(), kaldi::ReadBasicType(), kaldi::ReadToken(), MultiBasisComponent::selector_, MultiBasisComponent::selector_lr_coef_, AffineTransform::SetLinearity(), and MatrixBase< Real >::SetUnit().
      
int32 MultiBasisComponent::NumParams () const [inline, virtual]
Number of trainable parameters.
Implements UpdatableComponent.
Definition at line 199 of file nnet-multibasis-component.h.
References MultiBasisComponent::nnet_basis_, Nnet::NumParams(), and MultiBasisComponent::selector_.
Referenced by MultiBasisComponent::GetParams(), and MultiBasisComponent::SetParams().
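Given its references, NumParams() presumably adds the selector's parameter count to the counts of all basis networks. As a worked example with made-up sizes: a selector with 164 parameters (a 40-to-4 affine layer, 40·4 + 4) and four bases with 4100 parameters each (40-to-100 affine layers, 40·100 + 100) give 164 + 4·4100 = 16564. TotalParams is a hypothetical helper.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Hypothetical helper: total = selector count + sum of basis counts
// (cf. Nnet::NumParams() on selector_ and on each entry of nnet_basis_).
int TotalParams(int selector_params, const std::vector<int>& basis_params) {
  return std::accumulate(basis_params.begin(), basis_params.end(), selector_params);
}
```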
      
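void MultiBasisComponent::PropagateFnc (const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) [inline, virtual]
Abstract interface for propagation/backpropagation.
Forward pass transformation (to be implemented by the derived class).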
Implements Component.
Definition at line 298 of file nnet-multibasis-component.h.
References CuVectorBase< Real >::AddColSumMat(), CuMatrixBase< Real >::AddDiagVecMat(), kaldi::ApproxEqual(), MultiBasisComponent::basis_out_, CuMatrixBase< Real >::ColRange(), Nnet::InputDim(), KALDI_ASSERT, kaldi::kNoTrans, MultiBasisComponent::nnet_basis_, MultiBasisComponent::posterior_, MultiBasisComponent::posterior_sum_, Nnet::Propagate(), MultiBasisComponent::selector_, and MultiBasisComponent::threshold_.
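The forward pass implied by the member documentation (selector posteriors weighting the basis outputs, with posterior_sum_ compared against threshold_ to skip unused bases) can be sketched as below. Several details are assumptions: which input columns feed the selector and which feed the bases (the references only show ColRange() and Nnet::InputDim()), that posterior_sum_ is the per-basis sum of posteriors over frames, and that a basis is skipped when that sum is not above the threshold. Net, ColSpan, Cols and MultiBasisForward are hypothetical; std::function stands in for an Nnet.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using Matrix = std::vector<std::vector<float>>;    // [frame][dim]
using Net = std::function<Matrix(const Matrix&)>;  // hypothetical stand-in for Nnet

struct ColSpan {
  std::size_t begin;
  std::size_t width;
};

// Keeps columns [begin, begin + width) of every row, in the spirit of ColRange().
Matrix Cols(const Matrix& m, ColSpan span) {
  Matrix r;
  for (const auto& row : m) {
    auto first = row.begin() + static_cast<std::ptrdiff_t>(span.begin);
    r.emplace_back(first, first + static_cast<std::ptrdiff_t>(span.width));
  }
  return r;
}

Matrix MultiBasisForward(const Matrix& in, const Net& selector, ColSpan selector_cols,
                         const std::vector<Net>& bases, ColSpan basis_cols,
                         std::size_t out_dim, float threshold) {
  const std::size_t num_basis = bases.size();
  // 1. Selector: one posterior per basis for every frame (cf. posterior_).
  const Matrix posterior = selector(Cols(in, selector_cols));
  assert(posterior.size() == in.size());
  // 2. Occupancy of each basis summed over frames (assumed meaning of posterior_sum_).
  std::vector<float> posterior_sum(num_basis, 0.0f);
  for (const auto& row : posterior) {
    assert(row.size() == num_basis);
    for (std::size_t i = 0; i < num_basis; ++i) posterior_sum[i] += row[i];
  }
  // 3. Posterior-weighted sum of basis outputs, skipping bases at or below the threshold.
  const Matrix basis_in = Cols(in, basis_cols);
  Matrix out(in.size(), std::vector<float>(out_dim, 0.0f));
  for (std::size_t i = 0; i < num_basis; ++i) {
    if (posterior_sum[i] <= threshold) continue;  // unused basis is never run
    const Matrix y = bases[i](basis_in);          // cf. basis_out_[i]
    assert(y.size() == out.size());
    for (std::size_t t = 0; t < out.size(); ++t)
      for (std::size_t d = 0; d < out_dim; ++d) out[t][d] += posterior[t][i] * y[t][d];
  }
  return out;
}
```

Skipping a basis before calling it saves the whole forward pass through that network, which is the point of thresholding on occupancy rather than on the weighted output.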
      
void MultiBasisComponent::ReadData (std::istream &is, bool binary) [inline, virtual]
Reads the component content.
Reimplemented from Component.
Definition at line 137 of file nnet-multibasis-component.h.
References kaldi::ExpectToken(), Nnet::InputDim(), Component::InputDim(), KALDI_ASSERT, KALDI_ERR, MultiBasisComponent::nnet_basis_, Nnet::OutputDim(), Component::OutputDim(), kaldi::Peek(), kaldi::PeekToken(), Nnet::Read(), kaldi::ReadBasicType(), kaldi::ReadToken(), MultiBasisComponent::selector_, and MultiBasisComponent::selector_lr_coef_.
      
void MultiBasisComponent::SetBiasLearnRateCoef (BaseFloat val) [inline, virtual]
Overriding the default, which was UpdatableComponent::SetBiasLearnRateCoef(...)
Reimplemented from UpdatableComponent.
Definition at line 419 of file nnet-multibasis-component.h.
References MultiBasisComponent::nnet_basis_ and UpdatableComponent::SetBiasLearnRateCoef().
      
void MultiBasisComponent::SetLearnRateCoef (BaseFloat val) [inline, virtual]
Overriding the default, which was UpdatableComponent::SetLearnRateCoef(...)
Reimplemented from UpdatableComponent.
Definition at line 400 of file nnet-multibasis-component.h.
References MultiBasisComponent::nnet_basis_ and UpdatableComponent::SetLearnRateCoef().
      
void MultiBasisComponent::SetParams (const VectorBase< BaseFloat > &params) [inline, virtual]
Set the trainable parameters from a flat vector.
Implements UpdatableComponent.
Definition at line 227 of file nnet-multibasis-component.h.
References KALDI_ASSERT, MultiBasisComponent::nnet_basis_, Nnet::NumParams(), MultiBasisComponent::NumParams(), VectorBase< Real >::Range(), MultiBasisComponent::selector_, and Nnet::SetParams().
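SetParams() runs in the opposite direction: a flat vector is cut into consecutive sub-ranges, one per network, whose lengths follow from each network's parameter count (the references include both Nnet::NumParams() and VectorBase::Range()). A minimal sketch with std::vector follows. The block order must match the one used when the vector was packed, and that order is not given on this page; UnpackParams is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Cuts a flat parameter vector into consecutive sub-ranges, one per block.
// Each block must already have its final size (cf. Nnet::NumParams()).
void UnpackParams(const std::vector<float>& params,
                  std::vector<std::vector<float>>* blocks) {
  std::size_t total = 0;
  for (const auto& b : *blocks) total += b.size();
  assert(params.size() == total);  // lengths must agree
  std::size_t offset = 0;
  for (auto& b : *blocks) {
    auto first = params.begin() + static_cast<std::ptrdiff_t>(offset);
    std::copy(first, first + static_cast<std::ptrdiff_t>(b.size()), b.begin());
    offset += b.size();
  }
}
```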
      
void MultiBasisComponent::SetTrainOptions (const NnetTrainOptions &opts) [inline, virtual]
Overriding the default, which was UpdatableComponent::SetTrainOptions(...)
Reimplemented from UpdatableComponent.
Definition at line 389 of file nnet-multibasis-component.h.
References MultiBasisComponent::nnet_basis_, MultiBasisComponent::selector_, and Nnet::SetTrainOptions().
      
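void MultiBasisComponent::Update (const CuMatrixBase< BaseFloat > &input, const CuMatrixBase< BaseFloat > &diff) [inline, virtual]
Compute gradient and update parameters.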
Implements UpdatableComponent.
Definition at line 380 of file nnet-multibasis-component.h.
      
void MultiBasisComponent::WriteData (std::ostream &os, bool binary) const [inline, virtual]
Writes the component content.
Reimplemented from Component.
Definition at line 176 of file nnet-multibasis-component.h.
References MultiBasisComponent::nnet_basis_, MultiBasisComponent::selector_, MultiBasisComponent::selector_lr_coef_, Nnet::Write(), kaldi::WriteBasicType(), and kaldi::WriteToken().
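Member Data Documentation
std::vector< CuMatrix< BaseFloat > > MultiBasisComponent::basis_out_ [private]
Definition at line 438 of file nnet-multibasis-component.h.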
Referenced by MultiBasisComponent::BackpropagateFnc(), and MultiBasisComponent::PropagateFnc().
      
std::vector< Nnet > MultiBasisComponent::nnet_basis_ [private]
The vector of 'basis' networks (output of basis is combined according to the posterior_ from the selector_)
Definition at line 437 of file nnet-multibasis-component.h.
Referenced by MultiBasisComponent::BackpropagateFnc(), MultiBasisComponent::GetBasis(), MultiBasisComponent::GetParams(), MultiBasisComponent::Info(), MultiBasisComponent::InfoBackPropagate(), MultiBasisComponent::InfoGradient(), MultiBasisComponent::InfoPropagate(), MultiBasisComponent::InitData(), MultiBasisComponent::NumParams(), MultiBasisComponent::PropagateFnc(), MultiBasisComponent::ReadData(), MultiBasisComponent::SetBiasLearnRateCoef(), MultiBasisComponent::SetLearnRateCoef(), MultiBasisComponent::SetParams(), MultiBasisComponent::SetTrainOptions(), and MultiBasisComponent::WriteData().
CuMatrix< BaseFloat > MultiBasisComponent::posterior_ [private]
The output of 'selector_'.
Definition at line 445 of file nnet-multibasis-component.h.
Referenced by MultiBasisComponent::BackpropagateFnc(), and MultiBasisComponent::PropagateFnc().
Vector< BaseFloat > MultiBasisComponent::posterior_sum_ [private]
Definition at line 446 of file nnet-multibasis-component.h.
Referenced by MultiBasisComponent::BackpropagateFnc(), MultiBasisComponent::InfoBackPropagate(), MultiBasisComponent::InfoGradient(), MultiBasisComponent::InfoPropagate(), and MultiBasisComponent::PropagateFnc().
      
Nnet MultiBasisComponent::selector_ [private]
Selector network.
Definition at line 441 of file nnet-multibasis-component.h.
Referenced by MultiBasisComponent::BackpropagateFnc(), MultiBasisComponent::GetParams(), MultiBasisComponent::Info(), MultiBasisComponent::InfoBackPropagate(), MultiBasisComponent::InfoGradient(), MultiBasisComponent::InfoPropagate(), MultiBasisComponent::InitData(), MultiBasisComponent::NumParams(), MultiBasisComponent::PropagateFnc(), MultiBasisComponent::ReadData(), MultiBasisComponent::SetParams(), MultiBasisComponent::SetTrainOptions(), and MultiBasisComponent::WriteData().
      
BaseFloat MultiBasisComponent::selector_lr_coef_ [private]
Definition at line 442 of file nnet-multibasis-component.h.
Referenced by MultiBasisComponent::BackpropagateFnc(), MultiBasisComponent::InitData(), MultiBasisComponent::ReadData(), and MultiBasisComponent::WriteData().
      
BaseFloat MultiBasisComponent::threshold_ [private]
Threshold, applied to posterior_sum_, disables the unused basis.
Definition at line 449 of file nnet-multibasis-component.h.
Referenced by MultiBasisComponent::BackpropagateFnc(), MultiBasisComponent::InfoBackPropagate(), MultiBasisComponent::InfoGradient(), MultiBasisComponent::InfoPropagate(), and MultiBasisComponent::PropagateFnc().
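Read together, the posterior_sum_ and threshold_ entries suggest the following occupancy rule, written here with T frames in the minibatch, p_{t,i} the selector posterior of basis i at frame t, s_i for posterior_sum_(i) and θ for threshold_. The strict inequality, and the premise that each posterior row sums to one (the kaldi::ApproxEqual() reference in PropagateFnc() hints at such a check), are assumptions not confirmed by this page.

```latex
s_i = \sum_{t=1}^{T} p_{t,i}, \qquad
\text{basis } i \text{ is evaluated only if } s_i > \theta, \qquad
\sum_{i} s_i = \sum_{t=1}^{T} \sum_{i} p_{t,i} = T
\quad \text{if } \sum_{i} p_{t,i} = 1 \ \forall t .
```

Under these assumptions θ is measured in frames of occupancy, not as a fraction of the minibatch.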