#include <nnet-lstm-projected.h>
Public Member Functions
LstmProjected (int32 input_dim, int32 output_dim)
~LstmProjected ()
Component * | Copy () const
Copy component (deep copy).
ComponentType | GetType () const
Get the type identification of the component.
void | InitData (std::istream &is)
Initialize the content of the component from the 'line' of the prototype.
void | ReadData (std::istream &is, bool binary)
Reads the component content.
void | WriteData (std::ostream &os, bool binary) const
Writes the component content.
int32 | NumParams () const
Number of trainable parameters.
void | GetGradient (VectorBase< BaseFloat > *gradient) const
Get the gradient reshaped as a vector.
void | GetParams (VectorBase< BaseFloat > *params) const
Get the trainable parameters reshaped as a vector.
void | SetParams (const VectorBase< BaseFloat > &params)
Set the trainable parameters from a vector.
std::string | Info () const
Print some additional info (after <ComponentName> and the dims).
std::string | InfoGradient () const
Print some additional info about the gradient (after <...> and dims).
void | ResetStreams (const std::vector< int32 > &stream_reset_flag)
TODO: Do we really need this?
void | PropagateFnc (const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out)
Abstract interface for propagation/backpropagation.
void | BackpropagateFnc (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrixBase< BaseFloat > *in_diff)
Backward pass transformation (to be implemented by descending class...).
void | Update (const CuMatrixBase< BaseFloat > &input, const CuMatrixBase< BaseFloat > &diff)
Compute gradient and update parameters.
Public Member Functions inherited from MultistreamComponent
MultistreamComponent (int32 input_dim, int32 output_dim)
bool | IsMultistream () const
Check if the component has the 'Recurrent' interface (trainable and recurrent).
virtual void | SetSeqLengths (const std::vector< int32 > &sequence_lengths)
int32 | NumStreams () const
Public Member Functions inherited from UpdatableComponent
UpdatableComponent (int32 input_dim, int32 output_dim)
virtual | ~UpdatableComponent ()
bool | IsUpdatable () const
Check if the component contains trainable parameters.
virtual void | SetTrainOptions (const NnetTrainOptions &opts)
Set the training options of the component.
const NnetTrainOptions & | GetTrainOptions () const
Get the training options from the component.
virtual void | SetLearnRateCoef (BaseFloat val)
Set the learn-rate coefficient.
virtual void | SetBiasLearnRateCoef (BaseFloat val)
Set the learn-rate coefficient for the bias.
Public Member Functions inherited from Component
Component (int32 input_dim, int32 output_dim)
Generic interface of a component.
virtual | ~Component ()
int32 | InputDim () const
Get the dimension of the input.
int32 | OutputDim () const
Get the dimension of the output.
void | Propagate (const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward-pass propagation 'in' -> 'out'.
void | Backpropagate (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
Perform backward-pass propagation 'out_diff' -> 'in_diff'.
void | Write (std::ostream &os, bool binary) const
Write the component to a stream.
Additional Inherited Members
Public Types inherited from Component
enum | ComponentType { kUnknown = 0x0, kUpdatableComponent = 0x0100, kAffineTransform, kLinearTransform, kConvolutionalComponent, kLstmProjected, kBlstmProjected, kRecurrentComponent, kActivationFunction = 0x0200, kSoftmax, kHiddenSoftmax, kBlockSoftmax, kSigmoid, kTanh, kParametricRelu, kDropout, kLengthNormComponent, kTranform = 0x0400, kRbm, kSplice, kCopy, kTranspose, kBlockLinearity, kAddShift, kRescale, kKlHmm = 0x0800, kSentenceAveragingComponent, kSimpleSentenceAveragingComponent, kAveragePoolingComponent, kMaxPoolingComponent, kFramePoolingComponent, kParallelComponent, kMultiBasisComponent }
Component type identification mechanism.
Static Public Member Functions inherited from Component
static const char * | TypeToMarker (ComponentType t)
Converts a component type to a marker.
static ComponentType | MarkerToType (const std::string &s)
Converts a marker to a component type (case insensitive).
static Component * | Init (const std::string &conf_line)
Initialize a component from a line in the config file.
static Component * | Read (std::istream &is, bool binary)
Read the component from a stream (static method).
Static Public Attributes inherited from Component
static const struct key_value | kMarkerMap []
The table with pairs of Component types and markers (defined in nnet-component.cc).
Protected Attributes inherited from MultistreamComponent
std::vector< int32 > | sequence_lengths_
Protected Attributes inherited from UpdatableComponent
NnetTrainOptions | opts_
Option-class with training hyper-parameters.
BaseFloat | learn_rate_coef_
Scalar applied to the learning rate for weight matrices (to be used in the ::Update method).
BaseFloat | bias_learn_rate_coef_
Scalar applied to the learning rate for the bias (to be used in the ::Update method).
Protected Attributes inherited from Component
int32 | input_dim_
Data members.
int32 | output_dim_
Dimension of the output of the Component.
Definition at line 48 of file nnet-lstm-projected.h.
LstmProjected (int32 input_dim, int32 output_dim) [inline]
Definition at line 50 of file nnet-lstm-projected.h.
Referenced by LstmProjected::Copy().
~LstmProjected () [inline]
Definition at line 60 of file nnet-lstm-projected.h.
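void BackpropagateFnc (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrixBase< BaseFloat > *in_diff) [inline, virtual]
Backward pass transformation (to be implemented by descending class...).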
Implements Component.
Definition at line 475 of file nnet-lstm-projected.h.
References CuMatrixBase< Real >::AddMatMat(), LstmProjected::backpropagate_buf_, LstmProjected::bias_corr_, LstmProjected::cell_diff_clip_, LstmProjected::cell_dim_, LstmProjected::diff_clip_, Component::input_dim_, kaldi::kNoTrans, kaldi::kSetZero, kaldi::kTrans, NnetTrainOptions::momentum, CuMatrixBase< Real >::NumRows(), MultistreamComponent::NumStreams(), UpdatableComponent::opts_, LstmProjected::peephole_f_c_, LstmProjected::peephole_f_c_corr_, LstmProjected::peephole_i_c_, LstmProjected::peephole_i_c_corr_, LstmProjected::peephole_o_c_, LstmProjected::peephole_o_c_corr_, LstmProjected::proj_dim_, LstmProjected::propagate_buf_, CuMatrixBase< Real >::RowRange(), MultistreamComponent::sequence_lengths_, LstmProjected::w_gifo_r_, LstmProjected::w_gifo_r_corr_, LstmProjected::w_gifo_x_, LstmProjected::w_gifo_x_corr_, LstmProjected::w_r_m_, and LstmProjected::w_r_m_corr_.
Component * Copy () const [inline, virtual]
Copy component (deep copy).
Implements Component.
Definition at line 63 of file nnet-lstm-projected.h.
References LstmProjected::LstmProjected().
void GetGradient (VectorBase< BaseFloat > *gradient) const [inline, virtual]
Get the gradient reshaped as a vector.
Implements UpdatableComponent.
Definition at line 191 of file nnet-lstm-projected.h.
References LstmProjected::bias_, LstmProjected::bias_corr_, VectorBase< Real >::Dim(), KALDI_ASSERT, LstmProjected::NumParams(), LstmProjected::peephole_f_c_, LstmProjected::peephole_f_c_corr_, LstmProjected::peephole_i_c_, LstmProjected::peephole_i_c_corr_, LstmProjected::peephole_o_c_, LstmProjected::peephole_o_c_corr_, VectorBase< Real >::Range(), LstmProjected::w_gifo_r_, LstmProjected::w_gifo_r_corr_, LstmProjected::w_gifo_x_, LstmProjected::w_gifo_x_corr_, LstmProjected::w_r_m_, and LstmProjected::w_r_m_corr_.
void GetParams (VectorBase< BaseFloat > *params) const [inline, virtual]
Get the trainable parameters reshaped as a vector.
Implements UpdatableComponent.
Definition at line 220 of file nnet-lstm-projected.h.
References LstmProjected::bias_, VectorBase< Real >::Dim(), KALDI_ASSERT, LstmProjected::NumParams(), LstmProjected::peephole_f_c_, LstmProjected::peephole_i_c_, LstmProjected::peephole_o_c_, VectorBase< Real >::Range(), LstmProjected::w_gifo_r_, LstmProjected::w_gifo_x_, and LstmProjected::w_r_m_.
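The reference list above suggests that GetParams() copies each trainable member into a consecutive range of the output vector and checks the total against NumParams(). A minimal sketch of that packing pattern follows; it uses plain std::vector instead of the Kaldi vector and matrix types, `PackParamsSketch` is a hypothetical name, and the order of the blocks is an assumption that this page does not state.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Copy each parameter block into consecutive ranges of one flat vector.
// 'blocks' stands in for the trainable members in some fixed (assumed) order;
// matrices are taken to be already flattened row by row.
inline std::vector<float> PackParamsSketch(
    const std::vector<std::vector<float>> &blocks) {
  std::size_t total = 0;
  for (const auto &b : blocks) total += b.size();

  std::vector<float> params(total);
  std::size_t offset = 0;
  for (const auto &b : blocks) {
    std::copy(b.begin(), b.end(), params.begin() + offset);
    offset += b.size();  // the next block starts where this one ended
  }
  assert(offset == total);  // every slot was written exactly once
  return params;
}
```

SetParams() would walk the same ranges in the same order and copy in the opposite direction, which is why both functions need to agree on one layout.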
ComponentType GetType () const [inline, virtual]
Get the type identification of the component.
Implements Component.
Definition at line 64 of file nnet-lstm-projected.h.
References Component::kLstmProjected.
std::string Info () const [inline, virtual]
Print some additional info (after <ComponentName> and the dims).
Reimplemented from Component.
Definition at line 278 of file nnet-lstm-projected.h.
References LstmProjected::bias_, UpdatableComponent::bias_learn_rate_coef_, LstmProjected::cell_clip_, LstmProjected::cell_dim_, LstmProjected::diff_clip_, LstmProjected::grad_clip_, UpdatableComponent::learn_rate_coef_, kaldi::nnet1::MomentStatistics(), LstmProjected::peephole_f_c_, LstmProjected::peephole_i_c_, LstmProjected::peephole_o_c_, kaldi::nnet1::ToString(), LstmProjected::w_gifo_r_, LstmProjected::w_gifo_x_, and LstmProjected::w_r_m_.
std::string InfoGradient () const [inline, virtual]
Print some additional info about the gradient (after <...> and dims).
Reimplemented from Component.
Definition at line 294 of file nnet-lstm-projected.h.
References LstmProjected::backpropagate_buf_, LstmProjected::bias_corr_, UpdatableComponent::bias_learn_rate_coef_, LstmProjected::cell_clip_, LstmProjected::cell_dim_, LstmProjected::diff_clip_, LstmProjected::grad_clip_, UpdatableComponent::learn_rate_coef_, kaldi::nnet1::MomentStatistics(), LstmProjected::peephole_f_c_corr_, LstmProjected::peephole_i_c_corr_, LstmProjected::peephole_o_c_corr_, LstmProjected::proj_dim_, LstmProjected::propagate_buf_, kaldi::nnet1::ToString(), LstmProjected::w_gifo_r_corr_, LstmProjected::w_gifo_x_corr_, and LstmProjected::w_r_m_corr_.
void InitData (std::istream &is) [inline, virtual]
Initialize the content of the component from the 'line' of the prototype.
Implements UpdatableComponent.
Definition at line 66 of file nnet-lstm-projected.h.
References LstmProjected::bias_, UpdatableComponent::bias_learn_rate_coef_, LstmProjected::cell_clip_, LstmProjected::cell_diff_clip_, LstmProjected::cell_dim_, LstmProjected::diff_clip_, LstmProjected::grad_clip_, Component::input_dim_, KALDI_ASSERT, KALDI_ERR, kaldi::kUndefined, UpdatableComponent::learn_rate_coef_, LstmProjected::peephole_f_c_, LstmProjected::peephole_i_c_, LstmProjected::peephole_o_c_, LstmProjected::proj_dim_, kaldi::nnet1::RandUniform(), kaldi::ReadBasicType(), kaldi::ReadToken(), LstmProjected::w_gifo_r_, LstmProjected::w_gifo_x_, and LstmProjected::w_r_m_.
int32 NumParams () const [inline, virtual]
Number of trainable parameters.
Implements UpdatableComponent.
Definition at line 181 of file nnet-lstm-projected.h.
References LstmProjected::bias_, LstmProjected::peephole_f_c_, LstmProjected::peephole_i_c_, LstmProjected::peephole_o_c_, LstmProjected::w_gifo_r_, LstmProjected::w_gifo_x_, and LstmProjected::w_r_m_.
Referenced by LstmProjected::GetGradient(), LstmProjected::GetParams(), and LstmProjected::SetParams().
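NumParams() references seven trainable members, so the count is presumably the sum of their element counts. The sketch below works through that arithmetic. The shapes in the comment are assumptions based on the usual LSTMP layout (four gate blocks stacked as g, i, f, o) and are not stated on this page; `LstmProjectedDims` and `CountParams` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical container for the three dimensions that fix all shapes.
// Assumed shapes (not stated on this page):
//   w_gifo_x_            : 4*cell_dim x input_dim
//   w_gifo_r_            : 4*cell_dim x proj_dim
//   bias_                : 4*cell_dim
//   peephole_{i,f,o}_c_  : cell_dim each
//   w_r_m_               : proj_dim x cell_dim
struct LstmProjectedDims {
  int32_t input_dim;
  int32_t cell_dim;
  int32_t proj_dim;
};

// Sum of the element counts of all seven trainable members.
inline int32_t CountParams(const LstmProjectedDims &d) {
  const int32_t w_gifo_x = 4 * d.cell_dim * d.input_dim;
  const int32_t w_gifo_r = 4 * d.cell_dim * d.proj_dim;
  const int32_t bias = 4 * d.cell_dim;
  const int32_t peepholes = 3 * d.cell_dim;
  const int32_t w_r_m = d.proj_dim * d.cell_dim;
  return w_gifo_x + w_gifo_r + bias + peepholes + w_r_m;
}
```

Under these assumptions, input_dim = 40, cell_dim = 8 and proj_dim = 4 give 1280 + 128 + 32 + 24 + 32 = 1496 parameters.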
void PropagateFnc (const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) [inline, virtual]
Abstract interface for propagation/backpropagation.
Forward pass transformation (to be implemented by descending class...).
Implements Component.
Definition at line 365 of file nnet-lstm-projected.h.
References CuMatrixBase< Real >::AddMatMat(), LstmProjected::bias_, LstmProjected::cell_clip_, LstmProjected::cell_dim_, CuMatrixBase< Real >::CopyFromMat(), KALDI_ASSERT, kaldi::kNoTrans, kaldi::kSetZero, kaldi::kTrans, CuMatrixBase< Real >::NumRows(), MultistreamComponent::NumStreams(), LstmProjected::peephole_f_c_, LstmProjected::peephole_i_c_, LstmProjected::peephole_o_c_, LstmProjected::prev_nnet_state_, LstmProjected::proj_dim_, LstmProjected::propagate_buf_, LstmProjected::ResetStreams(), CuMatrixBase< Real >::RowRange(), MultistreamComponent::sequence_lengths_, LstmProjected::w_gifo_r_, LstmProjected::w_gifo_x_, and LstmProjected::w_r_m_.
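For orientation, the members referenced by the forward pass line up with the standard projected LSTM (LSTMP) recurrence with peephole connections, written out below. The mapping is an assumption and is not confirmed by this page: the stacked matrices W_{gx}, W_{ix}, W_{fx}, W_{ox} are taken to correspond to w_gifo_x_, the stacked W_{gr}, W_{ir}, W_{fr}, W_{or} to w_gifo_r_, the biases b to bias_, the vectors p_i, p_f, p_o to peephole_i_c_, peephole_f_c_, peephole_o_c_, and W_{rm} to w_r_m_. Here x_t is the input frame, c_t the cell state and r_t the projected recurrent output.

```latex
\begin{aligned}
g_t &= \tanh(W_{gx} x_t + W_{gr} r_{t-1} + b_g) \\
i_t &= \sigma(W_{ix} x_t + W_{ir} r_{t-1} + p_i \odot c_{t-1} + b_i) \\
f_t &= \sigma(W_{fx} x_t + W_{fr} r_{t-1} + p_f \odot c_{t-1} + b_f) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
o_t &= \sigma(W_{ox} x_t + W_{or} r_{t-1} + p_o \odot c_t + b_o) \\
m_t &= o_t \odot \tanh(c_t) \\
r_t &= W_{rm} m_t
\end{aligned}
```

Since cell_clip_ is documented as clipping the cell values in the forward pass, it would plausibly bound c_t after the fourth equation; the exact placement is not given on this page.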
void ReadData (std::istream &is, bool binary) [inline, virtual]
Reads the component content.
Reimplemented from Component.
Definition at line 107 of file nnet-lstm-projected.h.
References LstmProjected::bias_, UpdatableComponent::bias_learn_rate_coef_, LstmProjected::cell_clip_, LstmProjected::cell_diff_clip_, LstmProjected::cell_dim_, LstmProjected::diff_clip_, kaldi::ExpectToken(), LstmProjected::grad_clip_, KALDI_ASSERT, KALDI_ERR, UpdatableComponent::learn_rate_coef_, kaldi::Peek(), kaldi::PeekToken(), LstmProjected::peephole_f_c_, LstmProjected::peephole_i_c_, LstmProjected::peephole_o_c_, kaldi::ReadBasicType(), kaldi::ReadToken(), LstmProjected::w_gifo_r_, LstmProjected::w_gifo_x_, and LstmProjected::w_r_m_.
void ResetStreams (const std::vector< int32 > &stream_reset_flag) [inline, virtual]
TODO: Do we really need this?
Reimplemented from MultistreamComponent.
Definition at line 352 of file nnet-lstm-projected.h.
References LstmProjected::cell_dim_, KALDI_ASSERT, kaldi::kSetZero, MultistreamComponent::NumStreams(), LstmProjected::prev_nnet_state_, and LstmProjected::proj_dim_.
Referenced by LstmProjected::PropagateFnc().
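Judging from the referenced members (prev_nnet_state_, NumStreams(), kSetZero), ResetStreams() clears the state that is carried across batches for selected streams. The following is a minimal sketch of that idea with plain std::vector storage. `StateBuffer` and `ResetStreamsSketch` are hypothetical names, and both the meaning of a flag value of 1 ("reset this stream") and the reallocation on a size mismatch are assumptions.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for prev_nnet_state_: one row of state per stream.
using StateBuffer = std::vector<std::vector<float>>;

// Zero the carried-over state of every stream whose flag is 1.
// If the buffer does not have one row per stream yet, allocate it as zeros.
inline void ResetStreamsSketch(const std::vector<int> &stream_reset_flag,
                               std::size_t state_dim, StateBuffer *state) {
  const std::size_t num_streams = stream_reset_flag.size();
  if (state->size() != num_streams) {
    state->assign(num_streams, std::vector<float>(state_dim, 0.0f));
    return;
  }
  for (std::size_t s = 0; s < num_streams; ++s) {
    if (stream_reset_flag[s] == 1) {
      (*state)[s].assign(state_dim, 0.0f);  // this stream starts a new sequence
    }
  }
}
```

Streams whose flag is 0 keep their state, which lets a long utterance continue across consecutive batches.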
void SetParams (const VectorBase< BaseFloat > &params) [inline, virtual]
Set the trainable parameters from a vector.
Implements UpdatableComponent.
Definition at line 249 of file nnet-lstm-projected.h.
References LstmProjected::bias_, VectorBase< Real >::Dim(), KALDI_ASSERT, LstmProjected::NumParams(), LstmProjected::peephole_f_c_, LstmProjected::peephole_i_c_, LstmProjected::peephole_o_c_, VectorBase< Real >::Range(), LstmProjected::w_gifo_r_, LstmProjected::w_gifo_x_, and LstmProjected::w_r_m_.
void Update (const CuMatrixBase< BaseFloat > &input, const CuMatrixBase< BaseFloat > &diff) [inline, virtual]
Compute gradient and update parameters.
Implements UpdatableComponent.
Definition at line 654 of file nnet-lstm-projected.h.
References LstmProjected::bias_, LstmProjected::bias_corr_, UpdatableComponent::bias_learn_rate_coef_, LstmProjected::grad_clip_, NnetTrainOptions::learn_rate, UpdatableComponent::learn_rate_coef_, UpdatableComponent::opts_, LstmProjected::peephole_f_c_, LstmProjected::peephole_f_c_corr_, LstmProjected::peephole_i_c_, LstmProjected::peephole_i_c_corr_, LstmProjected::peephole_o_c_, LstmProjected::peephole_o_c_corr_, LstmProjected::w_gifo_r_, LstmProjected::w_gifo_r_corr_, LstmProjected::w_gifo_x_, LstmProjected::w_gifo_x_corr_, LstmProjected::w_r_m_, and LstmProjected::w_r_m_corr_.
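Update() references the learning rate, the two learn-rate coefficients, grad_clip_ and each parameter together with its `_corr_` buffer. One plausible reading is a clipped SGD step applied per parameter block. The sketch below illustrates that reading only: `SgdStepSketch` is a hypothetical name, the std::vector storage replaces the Kaldi types, and applying the clip before the step (and skipping it when grad_clip is not positive) is an assumption.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Clip each gradient element to [-grad_clip, grad_clip] (skipped when
// grad_clip <= 0), then take one plain SGD step:
//   param -= learn_rate * coef * grad
// 'coef' plays the role of learn_rate_coef_ or bias_learn_rate_coef_.
inline void SgdStepSketch(float learn_rate, float coef, float grad_clip,
                          std::vector<float> *grad,
                          std::vector<float> *param) {
  assert(grad->size() == param->size());
  for (std::size_t i = 0; i < param->size(); ++i) {
    if (grad_clip > 0.0f) {
      (*grad)[i] = std::max(-grad_clip, std::min(grad_clip, (*grad)[i]));
    }
    (*param)[i] -= learn_rate * coef * (*grad)[i];
  }
}
```

With a learning rate of 0.5, a coefficient of 1.0 and a clip of 2.0, a gradient of 10.0 is clipped to 2.0 and moves a parameter from 1.0 to 0.0.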
void WriteData (std::ostream &os, bool binary) const [inline, virtual]
Writes the component content.
Reimplemented from Component.
Definition at line 150 of file nnet-lstm-projected.h.
References LstmProjected::bias_, UpdatableComponent::bias_learn_rate_coef_, LstmProjected::cell_clip_, LstmProjected::cell_diff_clip_, LstmProjected::cell_dim_, LstmProjected::diff_clip_, LstmProjected::grad_clip_, UpdatableComponent::learn_rate_coef_, LstmProjected::peephole_f_c_, LstmProjected::peephole_i_c_, LstmProjected::peephole_o_c_, LstmProjected::w_gifo_r_, LstmProjected::w_gifo_x_, LstmProjected::w_r_m_, kaldi::WriteBasicType(), and kaldi::WriteToken().
backpropagate_buf_ [private]
Definition at line 731 of file nnet-lstm-projected.h.
Referenced by LstmProjected::BackpropagateFnc(), and LstmProjected::InfoGradient().
bias_ [private]
Definition at line 710 of file nnet-lstm-projected.h.
Referenced by LstmProjected::GetGradient(), LstmProjected::GetParams(), LstmProjected::Info(), LstmProjected::InitData(), LstmProjected::NumParams(), LstmProjected::PropagateFnc(), LstmProjected::ReadData(), LstmProjected::SetParams(), LstmProjected::Update(), and LstmProjected::WriteData().
bias_corr_ [private]
Definition at line 711 of file nnet-lstm-projected.h.
Referenced by LstmProjected::BackpropagateFnc(), LstmProjected::GetGradient(), LstmProjected::InfoGradient(), and LstmProjected::Update().
cell_clip_ [private]
Clipping of 'cell-values' in the forward pass (per-frame).
Definition at line 693 of file nnet-lstm-projected.h.
Referenced by LstmProjected::Info(), LstmProjected::InfoGradient(), LstmProjected::InitData(), LstmProjected::PropagateFnc(), LstmProjected::ReadData(), and LstmProjected::WriteData().
cell_diff_clip_ [private]
Clipping of 'cell-derivatives' accumulated over CEC (per-frame).
Definition at line 695 of file nnet-lstm-projected.h.
Referenced by LstmProjected::BackpropagateFnc(), LstmProjected::InitData(), LstmProjected::ReadData(), and LstmProjected::WriteData().
cell_dim_ [private]
Definition at line 690 of file nnet-lstm-projected.h.
Referenced by LstmProjected::BackpropagateFnc(), LstmProjected::Info(), LstmProjected::InfoGradient(), LstmProjected::InitData(), LstmProjected::PropagateFnc(), LstmProjected::ReadData(), LstmProjected::ResetStreams(), and LstmProjected::WriteData().
diff_clip_ [private]
Clipping of 'derivatives' in backprop (per-frame).
Definition at line 694 of file nnet-lstm-projected.h.
Referenced by LstmProjected::BackpropagateFnc(), LstmProjected::Info(), LstmProjected::InfoGradient(), LstmProjected::InitData(), LstmProjected::ReadData(), and LstmProjected::WriteData().
grad_clip_ [private]
Clipping of the updates.
Definition at line 696 of file nnet-lstm-projected.h.
Referenced by LstmProjected::Info(), LstmProjected::InfoGradient(), LstmProjected::InitData(), LstmProjected::ReadData(), LstmProjected::Update(), and LstmProjected::WriteData().
peephole_f_c_ [private]
Definition at line 716 of file nnet-lstm-projected.h.
Referenced by LstmProjected::BackpropagateFnc(), LstmProjected::GetGradient(), LstmProjected::GetParams(), LstmProjected::Info(), LstmProjected::InitData(), LstmProjected::NumParams(), LstmProjected::PropagateFnc(), LstmProjected::ReadData(), LstmProjected::SetParams(), LstmProjected::Update(), and LstmProjected::WriteData().
peephole_f_c_corr_ [private]
Definition at line 720 of file nnet-lstm-projected.h.
Referenced by LstmProjected::BackpropagateFnc(), LstmProjected::GetGradient(), LstmProjected::InfoGradient(), and LstmProjected::Update().
peephole_i_c_ [private]
Definition at line 715 of file nnet-lstm-projected.h.
Referenced by LstmProjected::BackpropagateFnc(), LstmProjected::GetGradient(), LstmProjected::GetParams(), LstmProjected::Info(), LstmProjected::InitData(), LstmProjected::NumParams(), LstmProjected::PropagateFnc(), LstmProjected::ReadData(), LstmProjected::SetParams(), LstmProjected::Update(), and LstmProjected::WriteData().
peephole_i_c_corr_ [private]
Definition at line 719 of file nnet-lstm-projected.h.
Referenced by LstmProjected::BackpropagateFnc(), LstmProjected::GetGradient(), LstmProjected::InfoGradient(), and LstmProjected::Update().
peephole_o_c_ [private]
Definition at line 717 of file nnet-lstm-projected.h.
Referenced by LstmProjected::BackpropagateFnc(), LstmProjected::GetGradient(), LstmProjected::GetParams(), LstmProjected::Info(), LstmProjected::InitData(), LstmProjected::NumParams(), LstmProjected::PropagateFnc(), LstmProjected::ReadData(), LstmProjected::SetParams(), LstmProjected::Update(), and LstmProjected::WriteData().
peephole_o_c_corr_ [private]
Definition at line 721 of file nnet-lstm-projected.h.
Referenced by LstmProjected::BackpropagateFnc(), LstmProjected::GetGradient(), LstmProjected::InfoGradient(), and LstmProjected::Update().
prev_nnet_state_ [private]
Definition at line 699 of file nnet-lstm-projected.h.
Referenced by LstmProjected::PropagateFnc(), and LstmProjected::ResetStreams().
proj_dim_ [private]
Recurrent projection layer dimension.
Definition at line 691 of file nnet-lstm-projected.h.
Referenced by LstmProjected::BackpropagateFnc(), LstmProjected::InfoGradient(), LstmProjected::InitData(), LstmProjected::PropagateFnc(), and LstmProjected::ResetStreams().
propagate_buf_ [private]
Definition at line 728 of file nnet-lstm-projected.h.
Referenced by LstmProjected::BackpropagateFnc(), LstmProjected::InfoGradient(), and LstmProjected::PropagateFnc().
w_gifo_r_ [private]
Definition at line 706 of file nnet-lstm-projected.h.
Referenced by LstmProjected::BackpropagateFnc(), LstmProjected::GetGradient(), LstmProjected::GetParams(), LstmProjected::Info(), LstmProjected::InitData(), LstmProjected::NumParams(), LstmProjected::PropagateFnc(), LstmProjected::ReadData(), LstmProjected::SetParams(), LstmProjected::Update(), and LstmProjected::WriteData().
w_gifo_r_corr_ [private]
Definition at line 707 of file nnet-lstm-projected.h.
Referenced by LstmProjected::BackpropagateFnc(), LstmProjected::GetGradient(), LstmProjected::InfoGradient(), and LstmProjected::Update().
w_gifo_x_ [private]
Definition at line 702 of file nnet-lstm-projected.h.
Referenced by LstmProjected::BackpropagateFnc(), LstmProjected::GetGradient(), LstmProjected::GetParams(), LstmProjected::Info(), LstmProjected::InitData(), LstmProjected::NumParams(), LstmProjected::PropagateFnc(), LstmProjected::ReadData(), LstmProjected::SetParams(), LstmProjected::Update(), and LstmProjected::WriteData().
w_gifo_x_corr_ [private]
Definition at line 703 of file nnet-lstm-projected.h.
Referenced by LstmProjected::BackpropagateFnc(), LstmProjected::GetGradient(), LstmProjected::InfoGradient(), and LstmProjected::Update().
w_r_m_ [private]
Definition at line 724 of file nnet-lstm-projected.h.
Referenced by LstmProjected::BackpropagateFnc(), LstmProjected::GetGradient(), LstmProjected::GetParams(), LstmProjected::Info(), LstmProjected::InitData(), LstmProjected::NumParams(), LstmProjected::PropagateFnc(), LstmProjected::ReadData(), LstmProjected::SetParams(), LstmProjected::Update(), and LstmProjected::WriteData().
w_r_m_corr_ [private]
Definition at line 725 of file nnet-lstm-projected.h.
Referenced by LstmProjected::BackpropagateFnc(), LstmProjected::GetGradient(), LstmProjected::InfoGradient(), and LstmProjected::Update().