22 #ifndef KALDI_NNET_NNET_COMPONENT_H_ 23 #define KALDI_NNET_NNET_COMPONENT_H_ 101 static ComponentType
MarkerToType(
const std::string &s);
117 virtual ComponentType
GetType()
const = 0;
156 void Write(std::ostream &os,
bool binary)
const;
159 virtual std::string
Info()
const {
return ""; }
183 virtual void ReadData(std::istream &is,
bool binary) { }
186 virtual void WriteData(std::ostream &os,
bool binary)
const { }
197 ComponentType t,
int32 input_dim,
int32 output_dim
212 learn_rate_coef_(1.0),
213 bias_learn_rate_coef_(1.0)
225 virtual int32 NumParams()
const = 0;
252 learn_rate_coef_ = val;
257 bias_learn_rate_coef_ = val;
261 virtual void InitData(std::istream &is) = 0;
292 sequence_lengths_ = sequence_lengths;
296 return std::max<int32>(1, sequence_lengths_.size());
316 <<
" component. The input-dim is " <<
input_dim_ 317 <<
", the data had " << in.
NumCols() <<
" dims.";
332 <<
", the dim of output derivatives " << out_diff.
NumCols();
355 #endif // KALDI_NNET_NNET_COMPONENT_H_ This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
virtual void SetTrainOptions(const NnetTrainOptions &opts)
Set the training options to the component.
virtual std::string Info() const
Print some additional info (after <ComponentName> and the dims).
NnetTrainOptions opts_
Option-class with training hyper-parameters.
int32 input_dim_
Data members.
BaseFloat bias_learn_rate_coef_
Scalar applied to learning rate for bias (to be used in ::Update method).
BaseFloat learn_rate_coef_
Scalar applied to learning rate for weight matrices (to be used in ::Update method).
bool IsUpdatable() const
Check if the component contains trainable parameters.
Class UpdatableComponent is a Component which has trainable parameters, it contains SGD training hype...
Component(int32 input_dim, int32 output_dim)
Generic interface of a component.
void Backpropagate(const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
Perform backward-pass propagation 'out_diff' -> 'in_diff'.
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
static Component * Init(const std::string &conf_line)
Initialize a component from a line in a config file.
static Component * Read(std::istream &is, bool binary)
Read the component from a stream (static method).
virtual void SetLearnRateCoef(BaseFloat val)
Set the learn-rate coefficient.
ComponentType
Component type identification mechanism.
virtual void ReadData(std::istream &is, bool binary)
Reads the component content.
virtual bool IsUpdatable() const
Check if component has 'Updatable' interface (trainable components).
MultistreamComponent(int32 input_dim, int32 output_dim)
A pair of type and marker.
virtual void SetSeqLengths(const std::vector< int32 > &sequence_lengths)
static const char * TypeToMarker(ComponentType t)
Converts component type to marker.
const Component::ComponentType key
virtual void ResetStreams(const std::vector< int32 > &stream_reset_flag)
Optional function to reset the transfer of context (not used for BLSTMs).
virtual void SetBiasLearnRateCoef(BaseFloat val)
Set the learn-rate coefficient for bias.
void Write(std::ostream &os, bool binary) const
Write the component to a stream.
void Propagate(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward-pass propagation 'in' -> 'out'.
int32 InputDim() const
Get the dimension of the input.
static ComponentType MarkerToType(const std::string &s)
Converts marker to component type (case insensitive).
UpdatableComponent(int32 input_dim, int32 output_dim)
virtual Component * Copy() const =0
Copy component (deep copy).
virtual void WriteData(std::ostream &os, bool binary) const
Writes the component content.
const NnetTrainOptions & GetTrainOptions() const
Get the training options from the component.
virtual bool IsMultistream() const
Check if component has 'Recurrent' interface (trainable and recurrent).
virtual void InitData(std::istream &is)
Virtual interface for initialization and I/O.
virtual ~UpdatableComponent()
virtual std::string InfoGradient() const
Print some additional info about gradient (after <...> and dims).
Class MultistreamComponent is an extension of UpdatableComponent for recurrent networks, which are trained with parallel sequences.
int32 output_dim_
Dimension of the output of the Component.
virtual void PropagateFnc(const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out)=0
Abstract interface for propagation/backpropagation.
Matrix for CUDA computing.
MatrixIndexT NumCols() const
#define KALDI_ASSERT(cond)
virtual ComponentType GetType() const =0
Get the type identification of the component.
Abstract class, building block of the network.
std::vector< int32 > sequence_lengths_
int32 OutputDim() const
Get the dimension of the output.
MatrixIndexT NumRows() const
Dimensions.
virtual void BackpropagateFnc(const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrixBase< BaseFloat > *in_diff)=0
Backward pass transformation (to be implemented by descending class...)
Provides a vector abstraction class.
void Resize(MatrixIndexT rows, MatrixIndexT cols, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Allocate the memory.
static Component * NewComponentOfType(ComponentType t, int32 input_dim, int32 output_dim)
Private members (descending classes cannot call this).
bool IsMultistream() const
Check if component has 'Recurrent' interface (trainable and recurrent).