RecurrentComponent Class Reference

Component with recurrent connections and a 'tanh' non-linearity.

#include <nnet-recurrent.h>


Public Member Functions

 RecurrentComponent (int32 input_dim, int32 output_dim)
 
 ~RecurrentComponent ()
 
Component * Copy () const
 Copy component (deep copy).
 
ComponentType GetType () const
 Get the type identification of the component.
 
void InitData (std::istream &is)
 Initialize the content of the component from the 'line' of the prototype.
 
void ReadData (std::istream &is, bool binary)
 Reads the component content.
 
void WriteData (std::ostream &os, bool binary) const
 Writes the component content.
 
int32 NumParams () const
 Number of trainable parameters.
 
void GetGradient (VectorBase< BaseFloat > *gradient) const
 Get the gradient reshaped as a vector.
 
void GetParams (VectorBase< BaseFloat > *params) const
 Get the trainable parameters reshaped as a vector.
 
void SetParams (const VectorBase< BaseFloat > &params)
 Set the trainable parameters from a vector (the parameters reshaped as a vector).
 
std::string Info () const
 Print some additional info (after <ComponentName> and the dims).
 
std::string InfoGradient () const
 Print some additional info about the gradient (after <...> and the dims).
 
void PropagateFnc (const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out)
 Forward pass transformation (part of the abstract propagation/backpropagation interface).
 
void BackpropagateFnc (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrixBase< BaseFloat > *in_diff)
 Backward pass transformation (to be implemented by the derived class).
 
void Update (const CuMatrixBase< BaseFloat > &input, const CuMatrixBase< BaseFloat > &diff)
 Compute the gradient and update the parameters.
 
- Public Member Functions inherited from MultistreamComponent
 MultistreamComponent (int32 input_dim, int32 output_dim)
 
bool IsMultistream () const
 Check if component has 'Recurrent' interface (trainable and recurrent).
 
virtual void SetSeqLengths (const std::vector< int32 > &sequence_lengths)
 
int32 NumStreams () const
 
virtual void ResetStreams (const std::vector< int32 > &stream_reset_flag)
 Optional function to reset the transfer of context (not used for BLSTMs).
 
- Public Member Functions inherited from UpdatableComponent
 UpdatableComponent (int32 input_dim, int32 output_dim)
 
virtual ~UpdatableComponent ()
 
bool IsUpdatable () const
 Check if the component contains trainable parameters.
 
virtual void SetTrainOptions (const NnetTrainOptions &opts)
 Set the training options to the component.
 
const NnetTrainOptions & GetTrainOptions () const
 Get the training options from the component.
 
virtual void SetLearnRateCoef (BaseFloat val)
 Set the learn-rate coefficient.
 
virtual void SetBiasLearnRateCoef (BaseFloat val)
 Set the learn-rate coefficient for bias.
 
- Public Member Functions inherited from Component
 Component (int32 input_dim, int32 output_dim)
 Generic interface of a component.
 
virtual ~Component ()
 
int32 InputDim () const
 Get the dimension of the input.
 
int32 OutputDim () const
 Get the dimension of the output.
 
void Propagate (const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
 Perform forward-pass propagation 'in' -> 'out'.
 
void Backpropagate (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
 Perform backward-pass propagation 'out_diff' -> 'in_diff'.
 
void Write (std::ostream &os, bool binary) const
 Write the component to a stream.
 

Private Attributes

BaseFloat grad_clip_
 Clipping of the update.
 
BaseFloat diff_clip_
 Clipping in the BPTT loop.
 
CuMatrix< BaseFloat > w_forward_
 
CuMatrix< BaseFloat > w_recurrent_
 
CuVector< BaseFloat > bias_
 
CuMatrix< BaseFloat > w_forward_corr_
 
CuMatrix< BaseFloat > w_recurrent_corr_
 
CuVector< BaseFloat > bias_corr_
 
CuMatrix< BaseFloat > out_
 
CuMatrix< BaseFloat > out_diff_bptt_
 

Additional Inherited Members

- Public Types inherited from Component
enum  ComponentType {
  kUnknown = 0x0, kUpdatableComponent = 0x0100, kAffineTransform, kLinearTransform,
  kConvolutionalComponent, kLstmProjected, kBlstmProjected, kRecurrentComponent,
  kActivationFunction = 0x0200, kSoftmax, kHiddenSoftmax, kBlockSoftmax,
  kSigmoid, kTanh, kParametricRelu, kDropout,
  kLengthNormComponent, kTranform = 0x0400, kRbm, kSplice,
  kCopy, kTranspose, kBlockLinearity, kAddShift,
  kRescale, kKlHmm = 0x0800, kSentenceAveragingComponent, kSimpleSentenceAveragingComponent,
  kAveragePoolingComponent, kMaxPoolingComponent, kFramePoolingComponent, kParallelComponent,
  kMultiBasisComponent
}
 Component type identification mechanism.
 
- Static Public Member Functions inherited from Component
static const char * TypeToMarker (ComponentType t)
 Converts component type to marker.
 
static ComponentType MarkerToType (const std::string &s)
 Converts marker to component type (case insensitive).
 
static Component * Init (const std::string &conf_line)
 Initialize component from a line in config file.
 
static Component * Read (std::istream &is, bool binary)
 Read the component from a stream (static method).
 
- Static Public Attributes inherited from Component
static const struct key_value kMarkerMap []
 The table with pairs of Component types and markers (defined in nnet-component.cc).
 
- Protected Attributes inherited from MultistreamComponent
std::vector< int32 > sequence_lengths_
 
- Protected Attributes inherited from UpdatableComponent
NnetTrainOptions opts_
 Option-class with training hyper-parameters.
 
BaseFloat learn_rate_coef_
 Scalar applied to learning rate for weight matrices (to be used in ::Update method).
 
BaseFloat bias_learn_rate_coef_
 Scalar applied to learning rate for bias (to be used in ::Update method).
 
- Protected Attributes inherited from Component
int32 input_dim_
 Dimension of the input of the Component.
 
int32 output_dim_
 Dimension of the output of the Component.
 

Detailed Description

Component with recurrent connections and a 'tanh' non-linearity.

No internal state is preserved; each sequence starts from a zero vector.

Can be used in both 'per-sentence' training and multi-stream training.
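
Read in column-vector notation, the PropagateFnc() listing on this page appears to compute the recurrence below. The symbols x_t, h_t, W_f, W_r and b are introduced here for illustration only; they stand for the input frame, the output frame, w_forward_, w_recurrent_ and bias_.

```latex
\begin{aligned}
h_0 &= \tanh(W_f x_0 + b), \\
h_t &= \tanh(W_f x_t + W_r h_{t-1} + b), \qquad t = 1, \dots, T-1 .
\end{aligned}
```

The first line is the second one with h_{-1} = 0, which matches the "starting from zero vector" behaviour described above.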

Definition at line 43 of file nnet-recurrent.h.

Constructor & Destructor Documentation

◆ RecurrentComponent()

RecurrentComponent ( int32  input_dim,
int32  output_dim 
)
inline

Definition at line 45 of file nnet-recurrent.h.

Referenced by RecurrentComponent::Copy().

45  RecurrentComponent(int32 input_dim, int32 output_dim):
46  MultistreamComponent(input_dim, output_dim)
47  { }

◆ ~RecurrentComponent()

~RecurrentComponent ( )
inline

Definition at line 49 of file nnet-recurrent.h.

49  ~RecurrentComponent()
50  { }

Member Function Documentation

◆ BackpropagateFnc()

void BackpropagateFnc ( const CuMatrixBase< BaseFloat > &  in,
const CuMatrixBase< BaseFloat > &  out,
const CuMatrixBase< BaseFloat > &  out_diff,
CuMatrixBase< BaseFloat > *  in_diff 
)
inline virtual

Backward pass transformation (to be implemented by the derived class).

Implements Component.

Definition at line 243 of file nnet-recurrent.h.

References CuMatrixBase< Real >::AddMatMat(), CuMatrixBase< Real >::ApplyCeiling(), CuMatrixBase< Real >::ApplyFloor(), RecurrentComponent::diff_clip_, CuMatrixBase< Real >::DiffTanh(), kaldi::kNoTrans, CuMatrixBase< Real >::NumRows(), MultistreamComponent::NumStreams(), RecurrentComponent::out_diff_bptt_, CuMatrixBase< Real >::RowRange(), MultistreamComponent::sequence_lengths_, RecurrentComponent::w_forward_, and RecurrentComponent::w_recurrent_.

246  {
247 
248  int32 T = in.NumRows() / NumStreams();
249  int32 S = NumStreams();
250 
251  // Apply BPTT on 'out_diff',
252  out_diff_bptt_ = out_diff;
253  for (int32 t = T-1; t >= 1; t--) {
254  // buffers,
255  CuSubMatrix<BaseFloat> d_t = out_diff_bptt_.RowRange(t*S, S);
256  CuSubMatrix<BaseFloat> d_t1 = out_diff_bptt_.RowRange((t-1)*S, S);
257  const CuSubMatrix<BaseFloat> y_t = out.RowRange(t*S, S);
258 
259  // BPTT,
260  d_t.DiffTanh(y_t, d_t);
261  d_t1.AddMatMat(1.0, d_t, kNoTrans, w_recurrent_, kNoTrans, 1.0);
262 
263  // clipping,
264  if (diff_clip_ > 0.0) {
265  d_t1.ApplyFloor(-diff_clip_);
266  d_t1.ApplyCeiling(diff_clip_);
267  }
268 
269  // Zero diff for padded frames,
270  if (sequence_lengths_.size() == S) {
271  for (int32 s = 0; s < S; s++) {
272  if (t >= sequence_lengths_[s]) {
273  out_diff_bptt_.Row(t*S + s).SetZero();
274  }
275  }
276  }
277  }
278 
279  // Apply 'DiffTanh' on first block,
280  CuSubMatrix<BaseFloat> d_t = out_diff_bptt_.RowRange(0, S);
281  const CuSubMatrix<BaseFloat> y_t = out.RowRange(0, S);
282  d_t.DiffTanh(y_t, d_t);
283 
284  // Transform diffs to 'in_diff',
285  in_diff->AddMatMat(1.0, out_diff_bptt_, kNoTrans, w_forward_, kNoTrans, 0.0);
286 
287  // We are DONE ;)
288  }
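
The loop above can be illustrated with a minimal standalone sketch. It is not the Kaldi code: it assumes a single stream, replaces CuMatrix with nested std::vector (the alias Rows and the function BpttDiffs are hypothetical names), and leaves out the zeroing of padded frames. It only shows the three steps of each iteration: multiply by the tanh derivative 1 - y^2, propagate to the previous frame through the recurrent weights, then clip.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical container: one inner vector per frame (or per matrix row).
using Rows = std::vector<std::vector<float>>;

// Back-propagation through time for a single stream.
// 'out' holds the tanh outputs y_t, 'diff' the derivatives coming from the
// layer above, 'w_recurrent' is a (dim x dim) matrix. The returned rows are
// the derivatives w.r.t. the pre-activations of every frame.
Rows BpttDiffs(const Rows& out, Rows diff, const Rows& w_recurrent,
               float diff_clip) {
  const std::size_t T = out.size();
  assert(diff.size() == T);
  if (T == 0) return diff;
  const std::size_t dim = w_recurrent.size();
  for (std::size_t t = T - 1; t >= 1; --t) {
    // Through the tanh: d_t <- d_t * (1 - y_t^2).
    for (std::size_t j = 0; j < dim; ++j)
      diff[t][j] *= 1.0f - out[t][j] * out[t][j];
    // To the previous frame: d_{t-1} += d_t * W_r (no transpose).
    for (std::size_t j = 0; j < dim; ++j)
      for (std::size_t k = 0; k < dim; ++k)
        diff[t - 1][k] += diff[t][j] * w_recurrent[j][k];
    // Clip the accumulated derivative to [-diff_clip, diff_clip].
    if (diff_clip > 0.0f) {
      for (std::size_t k = 0; k < dim; ++k)
        diff[t - 1][k] =
            std::max(-diff_clip, std::min(diff_clip, diff[t - 1][k]));
    }
  }
  // The first frame only needs the tanh derivative.
  for (std::size_t j = 0; j < dim; ++j)
    diff[0][j] *= 1.0f - out[0][j] * out[0][j];
  return diff;
}
```

As in the listing, clipping is applied to the accumulated derivative of the previous frame, so it bounds what is passed further back in time rather than the final parameter gradient.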

◆ Copy()

Component* Copy ( ) const
inline virtual

Copy component (deep copy).

Implements Component.

Definition at line 52 of file nnet-recurrent.h.

References RecurrentComponent::RecurrentComponent().

52  Component* Copy() const { return new RecurrentComponent(*this); }

◆ GetGradient()

void GetGradient ( VectorBase< BaseFloat > *  gradient) const
inline virtual

Get the gradient reshaped as a vector.

Implements UpdatableComponent.

Definition at line 133 of file nnet-recurrent.h.

References RecurrentComponent::bias_corr_, VectorBase< Real >::Dim(), KALDI_ASSERT, RecurrentComponent::NumParams(), VectorBase< Real >::Range(), RecurrentComponent::w_forward_corr_, and RecurrentComponent::w_recurrent_corr_.

133  {
134  KALDI_ASSERT(gradient->Dim() == NumParams());
135  int32 offset, len;
136 
137  offset = 0; len = w_forward_corr_.NumRows() * w_forward_corr_.NumCols();
138  gradient->Range(offset, len).CopyRowsFromMat(w_forward_corr_);
139 
140  offset += len; len = w_recurrent_corr_.NumRows() * w_recurrent_corr_.NumCols();
141  gradient->Range(offset, len).CopyRowsFromMat(w_recurrent_corr_);
142 
143  offset += len; len = bias_corr_.Dim();
144  gradient->Range(offset, len).CopyFromVec(bias_corr_);
145 
146  offset += len;
147  KALDI_ASSERT(offset == NumParams());
148  }

◆ GetParams()

void GetParams ( VectorBase< BaseFloat > *  params) const
inline virtual

Get the trainable parameters reshaped as a vector.

Implements UpdatableComponent.

Definition at line 150 of file nnet-recurrent.h.

References RecurrentComponent::bias_, VectorBase< Real >::Dim(), KALDI_ASSERT, RecurrentComponent::NumParams(), VectorBase< Real >::Range(), RecurrentComponent::w_forward_, and RecurrentComponent::w_recurrent_.

150  {
151  KALDI_ASSERT(params->Dim() == NumParams());
152  int32 offset, len;
153 
154  offset = 0; len = w_forward_.NumRows() * w_forward_.NumCols();
155  params->Range(offset, len).CopyRowsFromMat(w_forward_);
156 
157  offset += len; len = w_recurrent_.NumRows() * w_recurrent_.NumCols();
158  params->Range(offset, len).CopyRowsFromMat(w_recurrent_);
159 
160  offset += len; len = bias_.Dim();
161  params->Range(offset, len).CopyFromVec(bias_);
162 
163  offset += len;
164  KALDI_ASSERT(offset == NumParams());
165  }

◆ GetType()

ComponentType GetType ( ) const
inline virtual

Get the type identification of the component.

Implements Component.

Definition at line 53 of file nnet-recurrent.h.

References Component::kRecurrentComponent.

◆ Info()

std::string Info ( ) const
inline virtual

Print some additional info (after <ComponentName> and the dims).

Reimplemented from Component.

Definition at line 184 of file nnet-recurrent.h.

References RecurrentComponent::bias_, kaldi::nnet1::MomentStatistics(), RecurrentComponent::w_forward_, and RecurrentComponent::w_recurrent_.

184  {
185  return std::string(" ") +
186  "\n w_forward_ " + MomentStatistics(w_forward_) +
187  "\n w_recurrent_ " + MomentStatistics(w_recurrent_) +
188  "\n bias_ " + MomentStatistics(bias_);
189  }

◆ InfoGradient()

std::string InfoGradient ( ) const
inline virtual

Print some additional info about the gradient (after <...> and the dims).

Reimplemented from Component.

Definition at line 191 of file nnet-recurrent.h.

References RecurrentComponent::bias_corr_, UpdatableComponent::bias_learn_rate_coef_, RecurrentComponent::diff_clip_, RecurrentComponent::grad_clip_, UpdatableComponent::learn_rate_coef_, kaldi::nnet1::MomentStatistics(), RecurrentComponent::out_, RecurrentComponent::out_diff_bptt_, kaldi::nnet1::ToString(), RecurrentComponent::w_forward_corr_, and RecurrentComponent::w_recurrent_corr_.

191  {
192  return std::string("") +
193  "( learn_rate_coef " + ToString(learn_rate_coef_) +
194  ", bias_learn_rate_coef " + ToString(bias_learn_rate_coef_) +
195  ", grad-clip " + ToString(grad_clip_) +
196  ", diff-clip " + ToString(diff_clip_) + " )" +
197  "\n Gradients:" +
198  "\n w_forward_corr_ " + MomentStatistics(w_forward_corr_) +
199  "\n w_recurrent_corr_ " + MomentStatistics(w_recurrent_corr_) +
200  "\n bias_corr_ " + MomentStatistics(bias_corr_) +
201  "\n Forward-pass:" +
202  "\n out_ " + MomentStatistics(out_) +
203  "\n Backward-pass:" +
204  "\n out_diff_bptt_ " + MomentStatistics(out_diff_bptt_);
205  }

◆ InitData()

void InitData ( std::istream &  is)
inline virtual

Initialize the content of the component from the 'line' of the prototype.

Implements UpdatableComponent.

Definition at line 55 of file nnet-recurrent.h.

References RecurrentComponent::bias_, UpdatableComponent::bias_learn_rate_coef_, RecurrentComponent::diff_clip_, RecurrentComponent::grad_clip_, Component::input_dim_, KALDI_ERR, UpdatableComponent::learn_rate_coef_, Component::output_dim_, kaldi::nnet1::RandUniform(), kaldi::ReadBasicType(), kaldi::ReadToken(), RecurrentComponent::w_forward_, and RecurrentComponent::w_recurrent_.

55  {
56  // define options,
57  float param_scale = 0.02;
58  // parse the line from prototype,
59  std::string token;
60  while (is >> std::ws, !is.eof()) {
61  ReadToken(is, false, &token);
62  if (token == "<GradClip>") ReadBasicType(is, false, &grad_clip_);
63  else if (token == "<DiffClip>") ReadBasicType(is, false, &diff_clip_);
64  else if (token == "<LearnRateCoef>") ReadBasicType(is, false, &learn_rate_coef_);
65  else if (token == "<BiasLearnRateCoef>") ReadBasicType(is, false, &bias_learn_rate_coef_);
66  else if (token == "<ParamScale>") ReadBasicType(is, false, &param_scale);
67  else KALDI_ERR << "Unknown token " << token << ", a typo in config?"
68  << " (GradClip|DiffClip|LearnRateCoef|BiasLearnRateCoef|ParamScale)";
69  }
70 
71  // init the weights and biases (from uniform dist.),
72  w_forward_.Resize(output_dim_, input_dim_);
73  w_recurrent_.Resize(output_dim_, output_dim_);
74  bias_.Resize(output_dim_);
75 
76  RandUniform(0.0, 2.0 * param_scale, &w_forward_);
77  RandUniform(0.0, 2.0 * param_scale, &w_recurrent_);
78  RandUniform(0.0, 2.0 * param_scale, &bias_);
79  }
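
The option-parsing loop above can be sketched without the Kaldi I/O helpers. The version below is a simplified, standalone illustration: RecurrentOptions and ParseOptions are hypothetical names, std::istringstream stands in for ReadToken/ReadBasicType, and std::runtime_error stands in for KALDI_ERR. The defaults other than param_scale = 0.02 are assumptions made for the example, not values taken from the source.

```cpp
#include <cassert>
#include <sstream>
#include <stdexcept>
#include <string>

// Hypothetical holder for the options accepted on the prototype line.
struct RecurrentOptions {
  float grad_clip = 0.0f;             // assumed default
  float diff_clip = 0.0f;             // assumed default
  float learn_rate_coef = 1.0f;       // assumed default
  float bias_learn_rate_coef = 1.0f;  // assumed default
  float param_scale = 0.02f;          // default shown in the listing
};

// Parse "<Token> value" pairs in arbitrary order; unknown tokens are errors.
RecurrentOptions ParseOptions(const std::string& line) {
  RecurrentOptions opts;
  std::istringstream is(line);
  std::string token;
  while (is >> token) {
    float* target = nullptr;
    if (token == "<GradClip>") target = &opts.grad_clip;
    else if (token == "<DiffClip>") target = &opts.diff_clip;
    else if (token == "<LearnRateCoef>") target = &opts.learn_rate_coef;
    else if (token == "<BiasLearnRateCoef>") target = &opts.bias_learn_rate_coef;
    else if (token == "<ParamScale>") target = &opts.param_scale;
    else throw std::runtime_error("Unknown token " + token);
    if (!(is >> *target))
      throw std::runtime_error("Missing value after " + token);
  }
  return opts;
}
```

For example, ParseOptions("<GradClip> 5.0 <DiffClip> 1.0") would leave param_scale at 0.02. Since RandUniform is described on this page as taking a mean and a 'width', the three calls in the listing would then presumably draw the initial weights from about [-0.02, 0.02].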

◆ NumParams()

int32 NumParams ( ) const
inline virtual

Number of trainable parameters.

Implements UpdatableComponent.

Definition at line 127 of file nnet-recurrent.h.

References RecurrentComponent::bias_, RecurrentComponent::w_forward_, and RecurrentComponent::w_recurrent_.

Referenced by RecurrentComponent::GetGradient(), RecurrentComponent::GetParams(), and RecurrentComponent::SetParams().

127  {
128  return w_forward_.NumRows() * w_forward_.NumCols() +
129  w_recurrent_.NumRows() * w_recurrent_.NumCols() +
130  bias_.Dim();
131  }

◆ PropagateFnc()

void PropagateFnc ( const CuMatrixBase< BaseFloat > &  in,
CuMatrixBase< BaseFloat > *  out 
)
inline virtual

Forward pass transformation, part of the abstract propagation/backpropagation interface (to be implemented by the derived class).

Implements Component.

Definition at line 207 of file nnet-recurrent.h.

References CuMatrixBase< Real >::AddMatMat(), CuMatrixBase< Real >::AddVecToRows(), RecurrentComponent::bias_, KALDI_ASSERT, kaldi::kNoTrans, kaldi::kTrans, CuMatrixBase< Real >::NumRows(), MultistreamComponent::NumStreams(), RecurrentComponent::out_, CuMatrixBase< Real >::Row(), CuMatrixBase< Real >::RowRange(), MultistreamComponent::sequence_lengths_, RecurrentComponent::w_forward_, and RecurrentComponent::w_recurrent_.

208  {
209 
210 
211  KALDI_ASSERT(in.NumRows() % NumStreams() == 0);
212  int32 T = in.NumRows() / NumStreams();
213  int32 S = NumStreams();
214 
215  // Precopy bias,
216  out->AddVecToRows(1.0, bias_, 0.0);
217  // Apply 'forward' connections,
218  out->AddMatMat(1.0, in, kNoTrans, w_forward_, kTrans, 1.0);
219 
220  // First line of 'out' w/o recurrent signal, apply 'tanh' directly,
221  out->RowRange(0, S).Tanh(out->RowRange(0, S));
222 
223  // Apply 'recurrent' connections,
224  for (int32 t = 1; t < T; t++) {
225  out->RowRange(t*S, S).AddMatMat(1.0, out->RowRange((t-1)*S, S), kNoTrans, w_recurrent_, kTrans, 1.0);
226  out->RowRange(t*S, S).Tanh(out->RowRange(t*S, S));
227  // Zero output for padded frames,
228  if (sequence_lengths_.size() == S) {
229  for (int32 s = 0; s < S; s++) {
230  if (t >= sequence_lengths_[s]) {
231  out->Row(t*S + s).SetZero();
232  }
233  }
234  }
235  //
236  }
237 
238  out_ = (*out); // We'll need a copy for updating the recurrent weights!
239 
240  // We are DONE ;)
241  }
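
The multi-stream layout used above (row t*S + s holds frame t of stream s) may be easier to follow in a standalone sketch. The code below is an illustration on plain CPU memory, not the Kaldi implementation: Mat and RecurrentForward are hypothetical names, the matrix products are written as explicit loops, and, as a simplification, padded frames are skipped at every t (the listing only zeroes them for t >= 1, and only when sequence lengths were set).

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical row-major matrix, standing in for CuMatrix<BaseFloat>.
struct Mat {
  std::size_t rows, cols;
  std::vector<float> data;
  Mat(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0.0f) {}
  float& at(std::size_t r, std::size_t c) { return data[r * cols + c]; }
  float at(std::size_t r, std::size_t c) const { return data[r * cols + c]; }
};

// Forward pass over T time steps of S interleaved streams.
// w_forward is (out_dim x in_dim), w_recurrent is (out_dim x out_dim).
// seq_len[s] is the number of valid frames of stream s.
Mat RecurrentForward(const Mat& in, const Mat& w_forward,
                     const Mat& w_recurrent, const std::vector<float>& bias,
                     const std::vector<int>& seq_len) {
  const std::size_t S = seq_len.size();
  assert(S > 0 && in.rows % S == 0);
  const std::size_t T = in.rows / S;
  const std::size_t out_dim = w_forward.rows;
  assert(w_forward.cols == in.cols && bias.size() == out_dim);
  assert(w_recurrent.rows == out_dim && w_recurrent.cols == out_dim);
  Mat out(in.rows, out_dim);  // zero-initialised
  for (std::size_t t = 0; t < T; ++t) {
    for (std::size_t s = 0; s < S; ++s) {
      const std::size_t row = t * S + s;
      if (static_cast<int>(t) >= seq_len[s]) continue;  // padded: stays zero
      for (std::size_t j = 0; j < out_dim; ++j) {
        float a = bias[j];
        for (std::size_t i = 0; i < in.cols; ++i)
          a += in.at(row, i) * w_forward.at(j, i);
        if (t > 0) {  // previous frame of the same stream is S rows earlier
          for (std::size_t k = 0; k < out_dim; ++k)
            a += out.at(row - S, k) * w_recurrent.at(j, k);
        }
        out.at(row, j) = std::tanh(a);
      }
    }
  }
  return out;
}
```

Interleaving the streams this way keeps all S rows of one time step contiguous, which is what lets the listing handle a whole time step with a single RowRange(t*S, S) call.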

◆ ReadData()

void ReadData ( std::istream &  is,
bool  binary 
)
inline virtual

Reads the component content.

Reimplemented from Component.

Definition at line 81 of file nnet-recurrent.h.

References RecurrentComponent::bias_, UpdatableComponent::bias_learn_rate_coef_, RecurrentComponent::diff_clip_, kaldi::ExpectToken(), RecurrentComponent::grad_clip_, KALDI_ERR, UpdatableComponent::learn_rate_coef_, kaldi::Peek(), kaldi::PeekToken(), kaldi::ReadBasicType(), kaldi::ReadToken(), RecurrentComponent::w_forward_, and RecurrentComponent::w_recurrent_.

81  {
82  // Read all the '<Tokens>' in arbitrary order,
83  while ('<' == Peek(is, binary)) {
84  std::string token;
85  int first_char = PeekToken(is, binary);
86  switch (first_char) {
87  case 'G': ExpectToken(is, binary, "<GradClip>");
88  ReadBasicType(is, binary, &grad_clip_);
89  break;
90  case 'D': ExpectToken(is, binary, "<DiffClip>");
91  ReadBasicType(is, binary, &diff_clip_);
92  break;
93  case 'L': ExpectToken(is, binary, "<LearnRateCoef>");
94  ReadBasicType(is, binary, &learn_rate_coef_);
95  break;
96  case 'B': ExpectToken(is, binary, "<BiasLearnRateCoef>");
97  ReadBasicType(is, binary, &bias_learn_rate_coef_);
98  break;
99  default: ReadToken(is, false, &token);
100  KALDI_ERR << "Unknown token: " << token;
101  }
102  }
103 
104  // Read the data (data follow the tokens),
105  w_forward_.Read(is, binary);
106  w_recurrent_.Read(is, binary);
107  bias_.Read(is, binary);
108  }

◆ SetParams()

void SetParams ( const VectorBase< BaseFloat > &  params)
inline virtual

Set the trainable parameters from a vector (the parameters reshaped as a vector).

Implements UpdatableComponent.

Definition at line 167 of file nnet-recurrent.h.

References RecurrentComponent::bias_, VectorBase< Real >::Dim(), KALDI_ASSERT, RecurrentComponent::NumParams(), VectorBase< Real >::Range(), RecurrentComponent::w_forward_, and RecurrentComponent::w_recurrent_.

167  {
168  KALDI_ASSERT(params.Dim() == NumParams());
169  int32 offset, len;
170 
171  offset = 0; len = w_forward_.NumRows() * w_forward_.NumCols();
172  w_forward_.CopyRowsFromVec(params.Range(offset, len));
173 
174  offset += len; len = w_recurrent_.NumRows() * w_recurrent_.NumCols();
175  w_recurrent_.CopyRowsFromVec(params.Range(offset, len));
176 
177  offset += len; len = bias_.Dim();
178  bias_.CopyFromVec(params.Range(offset, len));
179 
180  offset += len;
181  KALDI_ASSERT(offset == NumParams());
182  }
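
GetParams() and SetParams() share one flat layout: the rows of w_forward_, then the rows of w_recurrent_, then bias_. The sketch below illustrates that layout and its round trip on plain std::vector storage; Mat, PackParams and UnpackParams are hypothetical names and not part of Kaldi.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical row-major matrix; 'data' holds rows * cols entries.
struct Mat {
  std::size_t rows, cols;
  std::vector<float> data;
};

// Flatten in the order used by the listings: forward weights, recurrent
// weights, bias. Row-major storage makes each copy a plain append.
std::vector<float> PackParams(const Mat& w_forward, const Mat& w_recurrent,
                              const std::vector<float>& bias) {
  std::vector<float> params;
  params.reserve(w_forward.data.size() + w_recurrent.data.size() + bias.size());
  params.insert(params.end(), w_forward.data.begin(), w_forward.data.end());
  params.insert(params.end(), w_recurrent.data.begin(), w_recurrent.data.end());
  params.insert(params.end(), bias.begin(), bias.end());
  return params;
}

// Inverse of PackParams; the destination shapes must already be set.
void UnpackParams(const std::vector<float>& params, Mat* w_forward,
                  Mat* w_recurrent, std::vector<float>* bias) {
  assert(params.size() ==
         w_forward->data.size() + w_recurrent->data.size() + bias->size());
  std::size_t offset = 0;
  for (float& v : w_forward->data) v = params[offset++];
  for (float& v : w_recurrent->data) v = params[offset++];
  for (float& v : *bias) v = params[offset++];
  assert(offset == params.size());  // same final check as in the listings
}
```

Because GetGradient() uses the same order for the *_corr_ buffers, a gradient vector and a parameter vector obtained this way line up element by element.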

◆ Update()

void Update ( const CuMatrixBase< BaseFloat > &  input,
const CuMatrixBase< BaseFloat > &  diff 
)
inline virtual

Compute the gradient and update the parameters.

Implements UpdatableComponent.

Definition at line 290 of file nnet-recurrent.h.

References RecurrentComponent::bias_, RecurrentComponent::bias_corr_, UpdatableComponent::bias_learn_rate_coef_, kaldi::kNoTrans, kaldi::kSetZero, kaldi::kTrans, NnetTrainOptions::learn_rate, UpdatableComponent::learn_rate_coef_, NnetTrainOptions::momentum, CuMatrixBase< Real >::NumRows(), MultistreamComponent::NumStreams(), UpdatableComponent::opts_, RecurrentComponent::out_, RecurrentComponent::out_diff_bptt_, Component::OutputDim(), RecurrentComponent::w_forward_, RecurrentComponent::w_forward_corr_, RecurrentComponent::w_recurrent_, and RecurrentComponent::w_recurrent_corr_.

291  {
292  int32 T = input.NumRows() / NumStreams();
293  int32 S = NumStreams();
294 
295  // getting the learning rate,
296  const BaseFloat lr = opts_.learn_rate;
297  const BaseFloat mmt = opts_.momentum;
298 
299  if (bias_corr_.Dim() != OutputDim()) {
300  w_forward_corr_.Resize(w_forward_.NumRows(), w_forward_.NumCols(), kSetZero);
301  w_recurrent_corr_.Resize(w_recurrent_.NumRows(), w_recurrent_.NumCols(), kSetZero);
302  bias_corr_.Resize(OutputDim(), kSetZero);
303  }
304 
305  // getting the gradients,
306  w_forward_corr_.AddMatMat(1.0, out_diff_bptt_, kTrans, input, kNoTrans, mmt);
307 
308 
309  w_recurrent_corr_.AddMatMat(1.0, out_diff_bptt_.RowRange(S, (T-1)*S), kTrans,
310  out_.RowRange(0, (T-1)*S), kNoTrans, mmt);
311 
312  bias_corr_.AddRowSumMat(1.0, out_diff_bptt_, mmt);
313 
314  // updating,
315  w_forward_.AddMat(-lr * learn_rate_coef_, w_forward_corr_);
316  w_recurrent_.AddMat(-lr * learn_rate_coef_, w_recurrent_corr_);
317  bias_.AddVec(-lr * bias_learn_rate_coef_, bias_corr_);
318  }
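
Each AddMatMat(1.0, ..., mmt) call above scales the old correction by the momentum before adding the new gradient, and the AddMat/AddVec calls then apply the scaled correction. A minimal sketch of that update rule on a flat vector follows; MomentumStep is a hypothetical helper, and the gradient is passed in directly instead of being computed from out_diff_bptt_.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One momentum-SGD step:
//   corr  <- grad + momentum * corr
//   param <- param - learn_rate * coef * corr
// 'coef' plays the role of learn_rate_coef_ or bias_learn_rate_coef_.
void MomentumStep(const std::vector<float>& grad, float learn_rate, float coef,
                  float momentum, std::vector<float>* corr,
                  std::vector<float>* param) {
  assert(grad.size() == corr->size() && grad.size() == param->size());
  for (std::size_t i = 0; i < grad.size(); ++i) {
    (*corr)[i] = grad[i] + momentum * (*corr)[i];
    (*param)[i] -= learn_rate * coef * (*corr)[i];
  }
}
```

With a constant gradient of 1, a learning rate of 0.1 and a momentum of 0.5, the correction would grow from 1.0 to 1.5 over two steps, so the second step moves the parameter further than the first.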

◆ WriteData()

void WriteData ( std::ostream &  os,
bool  binary 
) const
inline virtual

Writes the component content.

Reimplemented from Component.

Definition at line 110 of file nnet-recurrent.h.

References RecurrentComponent::bias_, UpdatableComponent::bias_learn_rate_coef_, RecurrentComponent::diff_clip_, RecurrentComponent::grad_clip_, UpdatableComponent::learn_rate_coef_, RecurrentComponent::w_forward_, RecurrentComponent::w_recurrent_, kaldi::WriteBasicType(), and kaldi::WriteToken().

110  {
111  WriteToken(os, binary, "<GradClip>");
112  WriteBasicType(os, binary, grad_clip_);
113  WriteToken(os, binary, "<DiffClip>");
114  WriteBasicType(os, binary, diff_clip_);
115 
116  WriteToken(os, binary, "<LearnRateCoef>");
117  WriteBasicType(os, binary, learn_rate_coef_);
118  WriteToken(os, binary, "<BiasLearnRateCoef>");
119  WriteBasicType(os, binary, bias_learn_rate_coef_);
120 
121  if (!binary) os << "\n";
122  w_forward_.Write(os, binary);
123  w_recurrent_.Write(os, binary);
124  bias_.Write(os, binary);
125  }

Member Data Documentation

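◆ bias_

CuVector< BaseFloat > bias_
private

◆ bias_corr_

CuVector< BaseFloat > bias_corr_
private

◆ diff_clip_

BaseFloat diff_clip_
private

Clipping in the BPTT loop.

◆ grad_clip_

BaseFloat grad_clip_
private

Clipping of the update.

◆ out_

CuMatrix< BaseFloat > out_
private

◆ out_diff_bptt_

CuMatrix< BaseFloat > out_diff_bptt_
private

◆ w_forward_

CuMatrix< BaseFloat > w_forward_
private

◆ w_forward_corr_

CuMatrix< BaseFloat > w_forward_corr_
private

◆ w_recurrent_

CuMatrix< BaseFloat > w_recurrent_
private

◆ w_recurrent_corr_

CuMatrix< BaseFloat > w_recurrent_corr_
private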


The documentation for this class was generated from the following file: nnet-recurrent.h