#include <nnet-nnet.h>
Public Member Functions
Nnet ()
~Nnet ()
Nnet (const Nnet &other)
Nnet & operator= (const Nnet &other)
void Propagate (const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform a forward pass through the network.
void Backpropagate (const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
Perform a backward pass through the network.
void Feedforward (const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform a forward pass through the network with 2 swapping buffers.
int32 InputDim () const
Dimensionality of the network input (input feature dim.).
int32 OutputDim () const
Dimensionality of the network outputs (posteriors | bn-features | etc.).
int32 NumComponents () const
Returns the number of 'Components' which form the NN.
const Component & GetComponent (int32 c) const
Component accessor.
Component & GetComponent (int32 c)
Component accessor.
const Component & GetLastComponent () const
Accessor for the last Component.
Component & GetLastComponent ()
Accessor for the last Component.
void ReplaceComponent (int32 c, const Component &comp)
Replace the c'th component in 'this' Nnet (deep copy).
void SwapComponent (int32 c, Component **comp)
Swap the c'th component with the pointer.
void AppendComponent (const Component &comp)
Append a Component to 'this' instance of Nnet (deep copy).
void AppendComponentPointer (Component *dynamically_allocated_comp)
Append a Component* to 'this' instance of Nnet by shallow copy ('this' instance of Nnet takes over ownership of the pointer).
void AppendNnet (const Nnet &nnet_to_append)
Append another Nnet to 'this' Nnet (copies all of its components).
void RemoveComponent (int32 c)
Remove the c'th component.
void RemoveLastComponent ()
Remove the last Component.
const std::vector< CuMatrix< BaseFloat > > & PropagateBuffer () const
Access to the forward-pass buffers.
const std::vector< CuMatrix< BaseFloat > > & BackpropagateBuffer () const
Access to the backward-pass buffers.
int32 NumParams () const
Get the number of parameters in the network.
void GetGradient (Vector< BaseFloat > *gradient) const
Get the gradient stored in the network.
void GetParams (Vector< BaseFloat > *params) const
Get the network weights as a supervector.
void SetParams (const VectorBase< BaseFloat > &params)
Set the network weights from a supervector.
void SetDropoutRate (BaseFloat r)
Set the dropout rate.
void ResetStreams (const std::vector< int32 > &stream_reset_flag)
Reset streams in multi-stream training.
void SetSeqLengths (const std::vector< int32 > &sequence_lengths)
Set the sequence lengths in LSTM multi-stream training.
void Init (const std::string &proto_file)
Initialize the Nnet from a prototype file.
void Read (const std::string &rxfilename)
Read the Nnet from 'rxfilename'.
void Read (std::istream &in, bool binary)
Read the Nnet from an 'istream'.
void Write (const std::string &wxfilename, bool binary) const
Write the Nnet to 'wxfilename'.
void Write (std::ostream &out, bool binary) const
Write the Nnet to an 'ostream'.
std::string Info () const
Create a string with a human-readable description of the nnet.
std::string InfoGradient (bool header=true) const
Create a string with per-component gradient statistics.
std::string InfoPropagate (bool header=true) const
Create a string with propagation-buffer statistics.
std::string InfoBackPropagate (bool header=true) const
Create a string with back-propagation-buffer statistics.
void Check () const
Consistency check.
void Destroy ()
Release the memory.
void SetTrainOptions (const NnetTrainOptions &opts)
Set the training hyper-parameters (pushed to all UpdatableComponents).
const NnetTrainOptions & GetTrainOptions () const
Get the training hyper-parameters from the network.
Private Attributes
std::vector< Component * > components_
Vector containing all the components that make up the neural network, e.g. AffineTransform, Sigmoid, Softmax.
std::vector< CuMatrix< BaseFloat > > propagate_buf_
Buffers for the forward pass (initialized on demand).
std::vector< CuMatrix< BaseFloat > > backpropagate_buf_
Buffers for the backward pass (initialized on demand).
NnetTrainOptions opts_
Option class with hyper-parameters passed to the UpdatableComponent(s).
Definition at line 37 of file nnet-nnet.h.
Nnet ( )
Definition at line 31 of file nnet-nnet.cc.
~Nnet ( )
Nnet ( const Nnet & other )
Definition at line 38 of file nnet-nnet.cc.
References Nnet::backpropagate_buf_, Nnet::Check(), Nnet::components_, Component::Copy(), Nnet::GetComponent(), Nnet::NumComponents(), Nnet::opts_, Nnet::propagate_buf_, and Nnet::SetTrainOptions().
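The copy constructor and operator= (whose References also include Nnet::Destroy()) deep-copy the component list by cloning each component with Component::Copy(). The following minimal sketch shows that pattern using hypothetical ToyComponent and ToyNnet types that own raw pointers. It assumes Copy() is a virtual clone. The self-assignment guard is an addition of this sketch and is not stated by the documentation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical polymorphic component; Copy() is a virtual clone standing in
// for nnet1's Component::Copy().
struct ToyComponent {
  int dim = 0;
  virtual ~ToyComponent() = default;
  virtual ToyComponent* Copy() const { return new ToyComponent(*this); }
};

// Hypothetical network owning raw pointers, so it needs the rule of three:
// deep-copying copy constructor and assignment, plus a releasing destructor.
class ToyNnet {
 public:
  ToyNnet() = default;
  ToyNnet(const ToyNnet& other) { CopyFrom(other); }
  ToyNnet& operator=(const ToyNnet& other) {
    if (this != &other) {  // guard added by this sketch
      Destroy();           // release what we currently own
      CopyFrom(other);
    }
    return *this;
  }
  ~ToyNnet() { Destroy(); }

  void AppendComponentPointer(ToyComponent* c) { components_.push_back(c); }
  std::size_t NumComponents() const { return components_.size(); }
  ToyComponent& GetComponent(std::size_t c) { return *components_.at(c); }

 private:
  // Clone every component so the two networks share nothing.
  void CopyFrom(const ToyNnet& other) {
    for (const ToyComponent* c : other.components_) components_.push_back(c->Copy());
  }
  // Free all owned components (cf. Nnet::Destroy()).
  void Destroy() {
    for (ToyComponent* c : components_) delete c;
    components_.clear();
  }
  std::vector<ToyComponent*> components_;
};
```

After copying, modifying a component in one network leaves the other network unchanged.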
void AppendComponent ( const Component & comp )
Append a Component to 'this' instance of Nnet (deep copy).
Definition at line 182 of file nnet-nnet.cc.
References Nnet::Check(), Nnet::components_, and Component::Copy().
Referenced by Nnet::AppendNnet(), main(), and Nnet::NumComponents().
void AppendComponentPointer ( Component * dynamically_allocated_comp )
Append a Component* to 'this' instance of Nnet by shallow copy ('this' instance of Nnet takes over ownership of the pointer).
Definition at line 187 of file nnet-nnet.cc.
References Nnet::Check(), and Nnet::components_.
Referenced by Nnet::Init(), main(), Nnet::NumComponents(), and Nnet::Read().
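The two append flavours differ only in ownership. AppendComponent() stores a clone, so the caller keeps its object. AppendComponentPointer() adopts a heap-allocated pointer and takes responsibility for deleting it. The following minimal sketch uses hypothetical ToyComponent and ToyNnet types. It holds components in std::unique_ptr for brevity, whereas the real components_ member is a std::vector< Component * >.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical stand-in for a Component; only the virtual clone matters here.
struct ToyComponent {
  int dim;
  explicit ToyComponent(int d) : dim(d) {}
  virtual ~ToyComponent() = default;
  virtual ToyComponent* Copy() const { return new ToyComponent(*this); }
};

class ToyNnet {
 public:
  // Deep copy: the caller keeps its object; the network stores a clone.
  void AppendComponent(const ToyComponent& comp) {
    components_.emplace_back(comp.Copy());
  }
  // Shallow copy: the network adopts the heap-allocated pointer and will
  // delete it; the caller must not delete it afterwards.
  void AppendComponentPointer(ToyComponent* dynamically_allocated_comp) {
    components_.emplace_back(dynamically_allocated_comp);
  }
  int NumComponents() const { return static_cast<int>(components_.size()); }
  const ToyComponent& GetComponent(int c) const { return *components_.at(c); }

 private:
  std::vector<std::unique_ptr<ToyComponent>> components_;
};
```

Use the deep-copy variant when the component lives on the stack or is still used elsewhere. The pointer variant avoids a copy for freshly allocated components, which is presumably why Init() and Read() reference it.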
void AppendNnet ( const Nnet & nnet_to_append )
Append another Nnet to 'this' Nnet (copies all of its components).
Definition at line 192 of file nnet-nnet.cc.
References Nnet::AppendComponent(), Nnet::Check(), Nnet::GetComponent(), and Nnet::NumComponents().
Referenced by main(), and Nnet::NumComponents().
void Backpropagate ( const CuMatrixBase< BaseFloat > & out_diff, CuMatrix< BaseFloat > * in_diff )
Perform a backward pass through the network.
Error back-propagation through the network, from the last component to the first.
Definition at line 96 of file nnet-nnet.cc.
References Nnet::backpropagate_buf_, Nnet::components_, KALDI_ASSERT, Nnet::NumComponents(), Nnet::propagate_buf_, and UpdatableComponent::Update().
Referenced by MultiBasisComponent::BackpropagateFnc(), and main().
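Backpropagate() walks the components from last to first. Its References include propagate_buf_, which suggests it reuses the activations cached by the forward pass. The following minimal sketch shows that data flow on single frames, using hypothetical Scale and Relu components. The real method works on CuMatrix batches and, per its References, also calls UpdatableComponent::Update(). This sketch omits parameter updates.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec = std::vector<float>;

// Hypothetical component interface operating on one frame.
struct ToyComponent {
  virtual ~ToyComponent() = default;
  virtual Vec Propagate(const Vec& in) const = 0;
  // Maps d(loss)/d(out) to d(loss)/d(in), given the cached forward values.
  virtual Vec Backpropagate(const Vec& in, const Vec& out,
                            const Vec& out_diff) const = 0;
};

struct Scale : ToyComponent {
  float a;
  explicit Scale(float s) : a(s) {}
  Vec Propagate(const Vec& in) const override {
    Vec out(in);
    for (float& v : out) v *= a;
    return out;
  }
  Vec Backpropagate(const Vec&, const Vec&, const Vec& out_diff) const override {
    Vec in_diff(out_diff);
    for (float& v : in_diff) v *= a;
    return in_diff;
  }
};

struct Relu : ToyComponent {
  Vec Propagate(const Vec& in) const override {
    Vec out(in);
    for (float& v : out) v = v > 0.0f ? v : 0.0f;
    return out;
  }
  Vec Backpropagate(const Vec& in, const Vec&, const Vec& out_diff) const override {
    Vec in_diff(out_diff);
    for (std::size_t i = 0; i < in.size(); ++i)
      if (in[i] <= 0.0f) in_diff[i] = 0.0f;  // gradient is zero where input was clipped
    return in_diff;
  }
};

using Net = std::vector<const ToyComponent*>;

// Forward pass keeping all activations: buf[0] is the input and
// buf[i + 1] is the output of component i (N + 1 buffers in this sketch).
std::vector<Vec> PropagateAll(const Net& net, const Vec& in) {
  std::vector<Vec> buf(net.size() + 1);
  buf[0] = in;
  for (std::size_t i = 0; i < net.size(); ++i) buf[i + 1] = net[i]->Propagate(buf[i]);
  return buf;
}

// Backward pass from the last component to the first; returns d(loss)/d(input).
Vec BackpropagateAll(const Net& net, const std::vector<Vec>& buf, const Vec& out_diff) {
  assert(buf.size() == net.size() + 1);
  Vec diff = out_diff;
  for (std::size_t i = net.size(); i-- > 0;)
    diff = net[i]->Backpropagate(buf[i], buf[i + 1], diff);
  return diff;
}
```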
const std::vector< CuMatrix< BaseFloat > > & BackpropagateBuffer ( ) const
Access to the backward-pass buffers.
Definition at line 109 of file nnet-nnet.h.
References Nnet::backpropagate_buf_, Nnet::Check(), Nnet::Destroy(), Nnet::GetGradient(), Nnet::GetParams(), Nnet::Info(), Nnet::InfoBackPropagate(), Nnet::InfoGradient(), Nnet::InfoPropagate(), Nnet::Init(), Nnet::NumParams(), Nnet::Read(), Nnet::ResetStreams(), Nnet::SetDropoutRate(), Nnet::SetParams(), Nnet::SetSeqLengths(), Nnet::SetTrainOptions(), and Nnet::Write().
void Check ( ) const
Consistency check.
Definition at line 467 of file nnet-nnet.cc.
References Nnet::components_, Nnet::GetParams(), KALDI_ASSERT, KALDI_ERR, KALDI_ISINF, KALDI_ISNAN, VectorBase< Real >::Sum(), and Component::TypeToMarker().
Referenced by Nnet::AppendComponent(), Nnet::AppendComponentPointer(), Nnet::AppendNnet(), Nnet::BackpropagateBuffer(), Nnet::Init(), Nnet::Nnet(), Nnet::operator=(), Nnet::Read(), Nnet::RemoveComponent(), Nnet::ReplaceComponent(), Nnet::SwapComponent(), and Nnet::Write().
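The References of Check() include GetParams(), VectorBase::Sum(), KALDI_ISNAN and KALDI_ISINF. This suggests it tests the summed parameter supervector for NaN/inf, and presumably it also checks that neighbouring dimensions agree. The following minimal sketch assumes both checks. It uses a hypothetical ToyComponentInfo record and throws exceptions where the real code would use KALDI_ERR.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical per-component summary: dims plus flattened parameters.
struct ToyComponentInfo {
  int input_dim;
  int output_dim;
  std::vector<float> params;  // empty for non-updatable components
};

// Throws if neighbouring components disagree on dimensions, or if the
// parameters contain NaN/inf (detected through their sum, which is NaN or
// inf whenever any element is).
void CheckNnet(const std::vector<ToyComponentInfo>& comps) {
  for (std::size_t i = 0; i + 1 < comps.size(); ++i) {
    if (comps[i].output_dim != comps[i + 1].input_dim) {
      throw std::runtime_error("dim mismatch between components " +
                               std::to_string(i) + " and " + std::to_string(i + 1));
    }
  }
  double sum = 0.0;
  for (const ToyComponentInfo& c : comps)
    sum = std::accumulate(c.params.begin(), c.params.end(), sum);
  if (std::isnan(sum) || std::isinf(sum))
    throw std::runtime_error("NaN or inf in network parameters");
}
```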
void Destroy ( )
Release the memory.
Definition at line 498 of file nnet-nnet.cc.
References Nnet::backpropagate_buf_, Nnet::components_, Nnet::NumComponents(), and Nnet::propagate_buf_.
Referenced by Nnet::BackpropagateBuffer(), Nnet::operator=(), and Nnet::~Nnet().
void Feedforward ( const CuMatrixBase< BaseFloat > & in, CuMatrix< BaseFloat > * out )
Perform a forward pass through the network with 2 swapping buffers.
Definition at line 131 of file nnet-nnet.cc.
References Nnet::components_, KALDI_ASSERT, Nnet::NumComponents(), and CuMatrix< Real >::Swap().
Referenced by main().
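Unlike Propagate(), Feedforward() uses only 2 buffers, and its References include CuMatrix::Swap(). As a result, memory does not grow with depth, but the intermediate activations are not kept for a backward pass. The following minimal ping-pong sketch runs on single frames, with a component modelled as a hypothetical std::function. Treating an empty network as the identity is an assumption of this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

using Vec = std::vector<float>;
// Hypothetical component: writes the output for one input frame.
using ToyComponent = std::function<void(const Vec& in, Vec* out)>;

// Forward pass with two ping-pong buffers instead of one buffer per
// component; fine for inference, but nothing is left for a backward pass.
void Feedforward(const std::vector<ToyComponent>& comps, const Vec& in, Vec* out) {
  assert(out != nullptr);
  if (comps.empty()) {  // assumption of this sketch: identity
    *out = in;
    return;
  }
  Vec cur, next;
  comps[0](in, &cur);
  for (std::size_t i = 1; i < comps.size(); ++i) {
    comps[i](cur, &next);
    std::swap(cur, next);  // the latest output always ends up in 'cur'
  }
  *out = std::move(cur);
}
```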
const Component & GetComponent ( int32 c ) const
Component accessor.
Definition at line 153 of file nnet-nnet.cc.
References Nnet::components_.
Referenced by Nnet::AppendNnet(), kaldi::ConvertNnet1ToNnet2(), Rbm::InitData(), main(), Nnet::Nnet(), Nnet::NumComponents(), Nnet::operator=(), Nnet::ResetStreams(), Nnet::SetDropoutRate(), Nnet::SetSeqLengths(), and Nnet::SetTrainOptions().
Component & GetComponent ( int32 c )
Component accessor.
Definition at line 157 of file nnet-nnet.cc.
References Nnet::components_.
void GetGradient ( Vector< BaseFloat > * gradient ) const
Get the gradient stored in the network.
Definition at line 221 of file nnet-nnet.cc.
References Nnet::components_, UpdatableComponent::GetGradient(), KALDI_ASSERT, Nnet::NumParams(), UpdatableComponent::NumParams(), VectorBase< Real >::Range(), and Vector< Real >::Resize().
Referenced by Nnet::BackpropagateBuffer().
const Component & GetLastComponent ( ) const
Accessor for the last Component.
Definition at line 161 of file nnet-nnet.cc.
References Nnet::components_, and Nnet::NumComponents().
Referenced by main(), and Nnet::NumComponents().
Component & GetLastComponent ( )
Accessor for the last Component.
Definition at line 165 of file nnet-nnet.cc.
References Nnet::components_, and Nnet::NumComponents().
void GetParams ( Vector< BaseFloat > * params ) const
Get the network weights as a supervector.
Definition at line 237 of file nnet-nnet.cc.
References Nnet::components_, UpdatableComponent::GetParams(), KALDI_ASSERT, Nnet::NumParams(), UpdatableComponent::NumParams(), VectorBase< Real >::Range(), and Vector< Real >::Resize().
Referenced by Nnet::BackpropagateBuffer(), Nnet::Check(), and MultiBasisComponent::GetParams().
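NumParams(), GetParams() and SetParams() treat the weights of all updatable components as one flat supervector. Their References to VectorBase::Range() suggest that each component occupies one consecutive range. The following minimal sketch implements this layout as free functions over a std::vector of hypothetical ToyComponent records, with std::vector<float> in place of Vector<BaseFloat>. GetGradient() references the same helpers and presumably uses the same layout.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical component: non-updatable ones (e.g. a Sigmoid) have no params.
struct ToyComponent {
  bool updatable;
  std::vector<float> params;
};

// Total number of parameters over all updatable components.
std::size_t NumParams(const std::vector<ToyComponent>& comps) {
  std::size_t n = 0;
  for (const ToyComponent& c : comps)
    if (c.updatable) n += c.params.size();
  return n;
}

// Concatenates the parameters of the updatable components, in network
// order, into one flat supervector.
std::vector<float> GetParams(const std::vector<ToyComponent>& comps) {
  std::vector<float> super;
  super.reserve(NumParams(comps));
  for (const ToyComponent& c : comps)
    if (c.updatable) super.insert(super.end(), c.params.begin(), c.params.end());
  return super;
}

// Inverse of GetParams(): copies consecutive ranges of the supervector back
// into the components; its size must equal NumParams().
void SetParams(const std::vector<float>& super, std::vector<ToyComponent>* comps) {
  assert(super.size() == NumParams(*comps));
  std::size_t offset = 0;
  for (ToyComponent& c : *comps) {
    if (!c.updatable) continue;
    for (float& p : c.params) p = super[offset++];
  }
}
```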
const NnetTrainOptions & GetTrainOptions ( ) const [inline]
Get the training hyper-parameters from the network.
Definition at line 163 of file nnet-nnet.h.
References Nnet::opts_.
std::string Info ( ) const
Create a string with a human-readable description of the nnet.
Definition at line 386 of file nnet-nnet.cc.
References Nnet::components_, Nnet::InputDim(), Nnet::NumComponents(), Nnet::NumParams(), Nnet::OutputDim(), and Component::TypeToMarker().
Referenced by Nnet::BackpropagateBuffer(), MultiBasisComponent::Info(), and main().
std::string InfoBackPropagate ( bool header = true ) const
Create a string with back-propagation-buffer statistics.
Definition at line 443 of file nnet-nnet.cc.
References Nnet::backpropagate_buf_, Nnet::components_, Component::kMultiBasisComponent, Component::kParallelComponent, kaldi::nnet1::MomentStatistics(), Nnet::NumComponents(), and Component::TypeToMarker().
Referenced by Nnet::BackpropagateBuffer(), MultiBasisComponent::InfoBackPropagate(), and main().
std::string InfoGradient ( bool header = true ) const
Create a string with per-component gradient statistics.
Definition at line 407 of file nnet-nnet.cc.
References Nnet::components_, Nnet::NumComponents(), and Component::TypeToMarker().
Referenced by Nnet::BackpropagateBuffer(), MultiBasisComponent::InfoGradient(), and main().
std::string InfoPropagate ( bool header = true ) const
Create a string with propagation-buffer statistics.
Definition at line 420 of file nnet-nnet.cc.
References Nnet::components_, Component::kMultiBasisComponent, Component::kParallelComponent, kaldi::nnet1::MomentStatistics(), Nnet::NumComponents(), Nnet::propagate_buf_, and Component::TypeToMarker().
Referenced by Nnet::BackpropagateBuffer(), MultiBasisComponent::InfoPropagate(), and main().
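InfoPropagate() and InfoBackPropagate() summarize each buffer with kaldi::nnet1::MomentStatistics(). The following sketch computes a hypothetical subset of such statistics (mean, standard deviation, min, max) for one buffer flattened to a std::vector<float>. The real function's fields and output format may differ. The line format below is made up.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical subset of buffer statistics.
struct Moments {
  double mean, stddev, min, max;
};

Moments ComputeMoments(const std::vector<float>& x) {
  assert(!x.empty());
  double sum = 0.0, sum_sq = 0.0;
  for (float v : x) {
    sum += v;
    sum_sq += static_cast<double>(v) * v;
  }
  const double n = static_cast<double>(x.size());
  const double mean = sum / n;
  // Population variance, clamped at 0 to absorb rounding error.
  const double var = std::max(0.0, sum_sq / n - mean * mean);
  const auto mm = std::minmax_element(x.begin(), x.end());
  return {mean, std::sqrt(var), *mm.first, *mm.second};
}

// One human-readable line per buffer, in a made-up format.
std::string FormatMoments(const Moments& m) {
  std::ostringstream os;
  os << "( min " << m.min << ", max " << m.max << ", mean " << m.mean
     << ", stddev " << m.stddev << " )";
  return os.str();
}
```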
void Init ( const std::string & proto_file )
Initialize the Nnet from a prototype file.
Definition at line 301 of file nnet-nnet.cc.
References Nnet::AppendComponentPointer(), Nnet::Check(), Input::Close(), Component::Init(), KALDI_ASSERT, KALDI_VLOG, and Input::Stream().
Referenced by Nnet::BackpropagateBuffer(), MultiBasisComponent::InitData(), ParallelComponent::InitData(), and main().
int32 InputDim ( ) const
Dimensionality of the network input (input feature dim.).
Definition at line 148 of file nnet-nnet.cc.
References Nnet::components_, and KALDI_ASSERT.
Referenced by MultiBasisComponent::BackpropagateFnc(), Nnet::Info(), MultiBasisComponent::InitData(), Rbm::InitData(), main(), MultiBasisComponent::PropagateFnc(), and MultiBasisComponent::ReadData().
int32 NumComponents ( ) const [inline]
Returns the number of 'Components' which form the NN.
Typically, an NN layer is composed of 2 Components: an <AffineTransform> with trainable parameters and a non-linearity such as <Sigmoid> or <Softmax>. There are therefore usually twice as many Components as NN layers.
Definition at line 66 of file nnet-nnet.h.
References Nnet::AppendComponent(), Nnet::AppendComponentPointer(), Nnet::AppendNnet(), Nnet::components_, Nnet::GetComponent(), Nnet::GetLastComponent(), Nnet::RemoveComponent(), Nnet::RemoveLastComponent(), Nnet::ReplaceComponent(), and Nnet::SwapComponent().
Referenced by Nnet::AppendNnet(), Nnet::Backpropagate(), kaldi::ConvertNnet1ToNnet2(), Nnet::Destroy(), Nnet::Feedforward(), Nnet::GetLastComponent(), Nnet::Info(), Nnet::InfoBackPropagate(), Nnet::InfoGradient(), Nnet::InfoPropagate(), main(), Nnet::Nnet(), Nnet::operator=(), Nnet::Propagate(), Nnet::Read(), Nnet::RemoveLastComponent(), Nnet::ResetStreams(), Nnet::SetDropoutRate(), Nnet::SetSeqLengths(), Nnet::SetTrainOptions(), and Nnet::Write().
int32 NumParams ( ) const
Get the number of parameters in the network.
Definition at line 210 of file nnet-nnet.cc.
References Nnet::components_.
Referenced by Nnet::BackpropagateBuffer(), Nnet::GetGradient(), Nnet::GetParams(), Nnet::Info(), MultiBasisComponent::NumParams(), Nnet::SetParams(), and MultiBasisComponent::SetParams().
Nnet & operator= ( const Nnet & other )
Definition at line 51 of file nnet-nnet.cc.
References Nnet::backpropagate_buf_, Nnet::Check(), Nnet::components_, Component::Copy(), Nnet::Destroy(), Nnet::GetComponent(), Nnet::NumComponents(), Nnet::opts_, Nnet::propagate_buf_, and Nnet::SetTrainOptions().
int32 OutputDim ( ) const
Dimensionality of the network outputs (posteriors | bn-features | etc.).
Definition at line 143 of file nnet-nnet.cc.
References Nnet::components_, and KALDI_ASSERT.
Referenced by Nnet::Info(), MultiBasisComponent::InitData(), main(), and MultiBasisComponent::ReadData().
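InputDim() and OutputDim() reference only components_ and KALDI_ASSERT. This is consistent with reading the input dimension of the first component and the output dimension of the last component of a non-empty network. The following sketch works under that assumption, with a hypothetical ToyComponent record. The 440/1024/3000 dimensions in the usage are arbitrary example values.

```cpp
#include <cassert>
#include <vector>

// Hypothetical component record holding only its dimensions.
struct ToyComponent {
  int input_dim;
  int output_dim;
};

// Network input dim: the input dim of the first component.
int InputDim(const std::vector<ToyComponent>& comps) {
  assert(!comps.empty());
  return comps.front().input_dim;
}

// Network output dim: the output dim of the last component.
int OutputDim(const std::vector<ToyComponent>& comps) {
  assert(!comps.empty());
  return comps.back().output_dim;
}
```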
void Propagate ( const CuMatrixBase< BaseFloat > & in, CuMatrix< BaseFloat > * out )
Perform a forward pass through the network.
Forward propagation through the network, from the first component to the last.
Definition at line 70 of file nnet-nnet.cc.
References Nnet::components_, Nnet::NumComponents(), and Nnet::propagate_buf_.
Referenced by main(), and MultiBasisComponent::PropagateFnc().
const std::vector< CuMatrix< BaseFloat > > & PropagateBuffer ( ) const
Access to the forward-pass buffers.
Definition at line 105 of file nnet-nnet.h.
References Nnet::propagate_buf_.
void Read ( const std::string & rxfilename )
Read the Nnet from 'rxfilename'.
I/O wrapper that converts 'rxfilename' to an 'istream'.
Definition at line 333 of file nnet-nnet.cc.
References Input::Close(), KALDI_WARN, Nnet::NumComponents(), and Input::Stream().
Referenced by Nnet::BackpropagateBuffer(), MultiBasisComponent::InitData(), ParallelComponent::InitData(), Rbm::InitData(), main(), ParallelComponent::ReadData(), and MultiBasisComponent::ReadData().
void Read ( std::istream & in, bool binary )
Read the Nnet from an 'istream'.
Definition at line 345 of file nnet-nnet.cc.
References Nnet::AppendComponentPointer(), Nnet::Check(), Nnet::components_, Component::InputDim(), KALDI_ERR, Nnet::NumComponents(), and Component::Read().
void RemoveComponent ( int32 c )
Remove the c'th component.
Definition at line 199 of file nnet-nnet.cc.
References Nnet::Check(), and Nnet::components_.
Referenced by main(), Nnet::NumComponents(), and Nnet::RemoveLastComponent().
void RemoveLastComponent ( )
Remove the last Component.
Definition at line 206 of file nnet-nnet.cc.
References Nnet::NumComponents(), and Nnet::RemoveComponent().
Referenced by main(), and Nnet::NumComponents().
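RemoveLastComponent() is expressed through RemoveComponent() and NumComponents(), according to its References. The following minimal sketch uses a hypothetical owning container in which std::unique_ptr frees the removed component. For example, repeated RemoveLastComponent() calls could strip trailing layers such as a final <Softmax>.

```cpp
#include <cassert>
#include <memory>
#include <vector>

struct ToyComponent {
  int id;
};

// Hypothetical owning container: erasing a slot destroys its component.
class ToyNnet {
 public:
  void Append(int id) { components_.push_back(std::make_unique<ToyComponent>(ToyComponent{id})); }
  int NumComponents() const { return static_cast<int>(components_.size()); }
  int IdAt(int c) const { return components_.at(c)->id; }

  void RemoveComponent(int c) {
    assert(c >= 0 && c < NumComponents());
    components_.erase(components_.begin() + c);  // later components shift down
  }
  void RemoveLastComponent() { RemoveComponent(NumComponents() - 1); }

 private:
  std::vector<std::unique_ptr<ToyComponent>> components_;
};
```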
void ReplaceComponent ( int32 c, const Component & comp )
Replace the c'th component in 'this' Nnet (deep copy).
Definition at line 169 of file nnet-nnet.cc.
References Nnet::Check(), Nnet::components_, and Component::Copy().
Referenced by Nnet::NumComponents().
void ResetStreams ( const std::vector< int32 > & stream_reset_flag )
Reset streams in multi-stream training.
Definition at line 281 of file nnet-nnet.cc.
References Nnet::GetComponent(), Component::IsMultistream(), Nnet::NumComponents(), and MultistreamComponent::ResetStreams().
Referenced by Nnet::BackpropagateBuffer(), and main().
void SetDropoutRate ( BaseFloat r )
Set the dropout rate.
Definition at line 268 of file nnet-nnet.cc.
References Nnet::GetComponent(), Dropout::GetDropoutRate(), Component::GetType(), KALDI_LOG, Component::kDropout, Nnet::NumComponents(), and Dropout::SetDropoutRate().
Referenced by Nnet::BackpropagateBuffer(), and main().
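SetDropoutRate() iterates over the components. Judging by its References to Component::GetType() and Component::kDropout, it only modifies Dropout components. The following sketch uses dynamic_cast on hypothetical types instead of a type enum. It updates every Dropout component and leaves the others untouched.

```cpp
#include <cassert>
#include <memory>
#include <vector>

struct ToyComponent {
  virtual ~ToyComponent() = default;
};
struct ToySigmoid : ToyComponent {};
struct ToyDropout : ToyComponent {
  float rate = 0.5f;
};

using ToyNnet = std::vector<std::unique_ptr<ToyComponent>>;

// Sets the rate on every Dropout component and skips all others.
void SetDropoutRate(ToyNnet* net, float r) {
  for (std::unique_ptr<ToyComponent>& c : *net) {
    if (ToyDropout* d = dynamic_cast<ToyDropout*>(c.get())) d->rate = r;
  }
}
```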
void SetParams ( const VectorBase< BaseFloat > & params )
Set the network weights from a supervector.
Definition at line 253 of file nnet-nnet.cc.
References Nnet::components_, VectorBase< Real >::Dim(), KALDI_ASSERT, Nnet::NumParams(), UpdatableComponent::NumParams(), VectorBase< Real >::Range(), and UpdatableComponent::SetParams().
Referenced by Nnet::BackpropagateBuffer(), and MultiBasisComponent::SetParams().
void SetSeqLengths ( const std::vector< int32 > & sequence_lengths )
Set the sequence lengths in LSTM multi-stream training.
Definition at line 291 of file nnet-nnet.cc.
References Nnet::GetComponent(), Component::IsMultistream(), Nnet::NumComponents(), and MultistreamComponent::SetSeqLengths().
Referenced by Nnet::BackpropagateBuffer(), and main().
void SetTrainOptions ( const NnetTrainOptions & opts )
Set the training hyper-parameters (pushed to all UpdatableComponents).
Definition at line 508 of file nnet-nnet.cc.
References Nnet::GetComponent(), Component::IsUpdatable(), Nnet::NumComponents(), and Nnet::opts_.
Referenced by Nnet::BackpropagateBuffer(), main(), Nnet::Nnet(), Nnet::operator=(), and MultiBasisComponent::SetTrainOptions().
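SetTrainOptions() stores the options in opts_, which GetTrainOptions() later returns. Its References to Component::IsUpdatable() suggest it pushes the options to every updatable component. The following minimal sketch uses hypothetical types. The learn_rate and momentum fields are assumed for illustration and are not taken from this page.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical options struct; the fields are assumed for illustration.
struct ToyTrainOptions {
  float learn_rate = 0.0f;
  float momentum = 0.0f;
};

struct ToyComponent {
  virtual ~ToyComponent() = default;
  virtual bool IsUpdatable() const { return false; }
};
struct ToyUpdatable : ToyComponent {
  ToyTrainOptions opts;
  bool IsUpdatable() const override { return true; }
};

class ToyNnet {
 public:
  void Append(std::unique_ptr<ToyComponent> c) { components_.push_back(std::move(c)); }
  ToyComponent& GetComponent(std::size_t i) { return *components_.at(i); }

  // Keeps a copy for GetTrainOptions() and pushes it to updatable components.
  void SetTrainOptions(const ToyTrainOptions& opts) {
    opts_ = opts;
    for (std::unique_ptr<ToyComponent>& c : components_)
      if (c->IsUpdatable()) static_cast<ToyUpdatable*>(c.get())->opts = opts;
  }
  const ToyTrainOptions& GetTrainOptions() const { return opts_; }

 private:
  std::vector<std::unique_ptr<ToyComponent>> components_;
  ToyTrainOptions opts_;
};
```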
void SwapComponent ( int32 c, Component ** comp )
Swap the c'th component with the pointer.
Definition at line 175 of file nnet-nnet.cc.
References Nnet::Check(), and Nnet::components_.
Referenced by Nnet::NumComponents().
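ReplaceComponent() and SwapComponent() differ in ownership. ReplaceComponent() stores a deep copy; its References include Component::Copy(). SwapComponent() exchanges raw pointers, so the caller receives the old component and becomes responsible for deleting it. The following sketch illustrates both behaviours with hypothetical ToyComponent and ToyNnet types. The exact pointer exchange is inferred from the description "Swap the c'th component with the pointer".

```cpp
#include <cassert>
#include <utility>
#include <vector>

struct ToyComponent {
  int id;
  explicit ToyComponent(int i) : id(i) {}
  virtual ~ToyComponent() = default;
  virtual ToyComponent* Copy() const { return new ToyComponent(*this); }
};

// Hypothetical network owning raw pointers (copying disabled for brevity).
class ToyNnet {
 public:
  ToyNnet() = default;
  ToyNnet(const ToyNnet&) = delete;
  ToyNnet& operator=(const ToyNnet&) = delete;
  ~ToyNnet() {
    for (ToyComponent* c : components_) delete c;
  }

  void AppendComponentPointer(ToyComponent* c) { components_.push_back(c); }
  const ToyComponent& GetComponent(int c) const { return *components_.at(c); }

  // Deep copy: clone first (safe even if 'comp' is the current component),
  // then free the old one.
  void ReplaceComponent(int c, const ToyComponent& comp) {
    ToyComponent* fresh = comp.Copy();
    delete components_.at(c);
    components_.at(c) = fresh;
  }
  // Pointer exchange: the network takes *comp, the caller gets the old one.
  void SwapComponent(int c, ToyComponent** comp) {
    std::swap(components_.at(c), *comp);
  }

 private:
  std::vector<ToyComponent*> components_;
};
```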
void Write ( const std::string & wxfilename, bool binary ) const
Write the Nnet to 'wxfilename'.
I/O wrapper that converts 'wxfilename' to an 'ostream'.
Definition at line 367 of file nnet-nnet.cc.
References Output::Close(), and Output::Stream().
Referenced by Nnet::BackpropagateBuffer(), main(), and MultiBasisComponent::WriteData().
void Write ( std::ostream & out, bool binary ) const
Write the Nnet to an 'ostream'.
Definition at line 374 of file nnet-nnet.cc.
References Nnet::Check(), Nnet::components_, Nnet::NumComponents(), and kaldi::WriteToken().
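Write(std::ostream&, bool) references kaldi::WriteToken(), which suggests that the serialized network is delimited by tokens. The following sketch performs a token-delimited text round trip with hypothetical ToyComponent records. The <Nnet>/</Nnet> markers and the "<Type> dim" line format are assumptions, and the real methods also support a binary mode.

```cpp
#include <cassert>
#include <istream>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical component record: a type marker and one dimension.
struct ToyComponent {
  std::string type;
  int dim;
};

// Writes the network between assumed <Nnet> ... </Nnet> tokens.
void WriteNnet(std::ostream& out, const std::vector<ToyComponent>& comps) {
  out << "<Nnet>\n";
  for (const ToyComponent& c : comps) out << c.type << ' ' << c.dim << '\n';
  out << "</Nnet>\n";
}

// Reads components until the closing token; throws on malformed input.
std::vector<ToyComponent> ReadNnet(std::istream& in) {
  std::string token;
  if (!(in >> token) || token != "<Nnet>") throw std::runtime_error("expected <Nnet>");
  std::vector<ToyComponent> comps;
  while (in >> token) {
    if (token == "</Nnet>") return comps;  // closing token found
    ToyComponent c{token, 0};
    if (!(in >> c.dim)) throw std::runtime_error("missing dim after " + token);
    comps.push_back(c);
  }
  throw std::runtime_error("missing </Nnet>");
}
```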
std::vector< CuMatrix< BaseFloat > > backpropagate_buf_ [private]
Buffers for the backward pass (initialized on demand).
Definition at line 175 of file nnet-nnet.h.
Referenced by Nnet::Backpropagate(), Nnet::BackpropagateBuffer(), Nnet::Destroy(), Nnet::InfoBackPropagate(), Nnet::Nnet(), and Nnet::operator=().
std::vector< Component * > components_ [private]
Vector containing all the components that make up the neural network, e.g. AffineTransform, Sigmoid, Softmax.
Definition at line 170 of file nnet-nnet.h.
Referenced by Nnet::AppendComponent(), Nnet::AppendComponentPointer(), Nnet::Backpropagate(), Nnet::Check(), Nnet::Destroy(), Nnet::Feedforward(), Nnet::GetComponent(), Nnet::GetGradient(), Nnet::GetLastComponent(), Nnet::GetParams(), Nnet::Info(), Nnet::InfoBackPropagate(), Nnet::InfoGradient(), Nnet::InfoPropagate(), Nnet::InputDim(), Nnet::Nnet(), Nnet::NumComponents(), Nnet::NumParams(), Nnet::operator=(), Nnet::OutputDim(), Nnet::Propagate(), Nnet::Read(), Nnet::RemoveComponent(), Nnet::ReplaceComponent(), Nnet::SetParams(), Nnet::SwapComponent(), and Nnet::Write().
NnetTrainOptions opts_ [private]
Option class with hyper-parameters passed to the UpdatableComponent(s).
Definition at line 178 of file nnet-nnet.h.
Referenced by Nnet::GetTrainOptions(), Nnet::Nnet(), Nnet::operator=(), and Nnet::SetTrainOptions().
std::vector< CuMatrix< BaseFloat > > propagate_buf_ [private]
Buffers for the forward pass (initialized on demand).
Definition at line 173 of file nnet-nnet.h.
Referenced by Nnet::Backpropagate(), Nnet::Destroy(), Nnet::InfoPropagate(), Nnet::Nnet(), Nnet::operator=(), Nnet::Propagate(), and Nnet::PropagateBuffer().