21 #ifndef KALDI_NNET_NNET_RBM_H_ 22 #define KALDI_NNET_NNET_RBM_H_ 61 virtual RbmNodeType
VisType()
const = 0;
62 virtual RbmNodeType
HidType()
const = 0;
64 virtual void WriteAsNnet(std::ostream& os,
bool binary)
const = 0;
106 return new Rbm(*
this);
115 std::string vis_type;
116 std::string hid_type;
117 float vis_bias_mean = 0.0, vis_bias_range = 0.0,
118 hid_bias_mean = 0.0, hid_bias_range = 0.0,
120 std::string vis_bias_cmvn_file;
123 while (is >> std::ws, !is.eof()) {
125 if (token ==
"<VisibleType>")
ReadToken(is,
false, &vis_type);
126 else if (token ==
"<HiddenType>")
ReadToken(is,
false, &hid_type);
127 else if (token ==
"<VisibleBiasMean>")
ReadBasicType(is,
false, &vis_bias_mean);
128 else if (token ==
"<VisibleBiasRange>")
ReadBasicType(is,
false, &vis_bias_range);
129 else if (token ==
"<HiddenBiasMean>")
ReadBasicType(is,
false, &hid_bias_mean);
130 else if (token ==
"<HiddenBiasRange>")
ReadBasicType(is,
false, &hid_bias_range);
131 else if (token ==
"<ParamStddev>")
ReadBasicType(is,
false, &param_stddev);
132 else if (token ==
"<VisibleBiasCmvnFilename>")
ReadToken(is,
false, &vis_bias_cmvn_file);
133 else KALDI_ERR <<
"Unknown token " << token <<
" Typo in config?";
138 else if (vis_type ==
"gauss" || vis_type ==
"Gaussian") vis_type_ =
RbmBase::Gaussian;
139 else KALDI_ERR <<
"Wrong <VisibleType>" << vis_type;
142 else if (hid_type ==
"gauss" || hid_type ==
"Gaussian") hid_type_ =
RbmBase::Gaussian;
143 else KALDI_ERR <<
"Wrong <HiddenType>" << hid_type;
153 RandUniform(hid_bias_mean, hid_bias_range, &hid_bias_);
155 if (vis_bias_cmvn_file ==
"") {
157 RandUniform(vis_bias_mean, vis_bias_range, &vis_bias_);
159 KALDI_LOG <<
"Initializing from <VisibleBiasCmvnFilename> " 160 << vis_bias_cmvn_file;
163 cmvn.
Read(vis_bias_cmvn_file);
169 p.ApplyFloor(0.0001);
170 p.ApplyCeiling(0.9999);
173 for (
int32 d = 0;
d < p.Dim();
d++) {
174 logit_p(
d) =
Log(p(
d)) -
Log(1.0 - p(
d));
183 std::string vis_node_type, hid_node_type;
187 if (vis_node_type ==
"bern") {
189 }
else if (vis_node_type ==
"gauss") {
192 if (hid_node_type ==
"bern") {
194 }
else if (hid_node_type ==
"gauss") {
198 vis_hid_.Read(is, binary);
199 vis_bias_.Read(is, binary);
200 hid_bias_.Read(is, binary);
212 default :
KALDI_ERR <<
"Unknown type " << vis_type_;
217 default :
KALDI_ERR <<
"Unknown type " << hid_type_;
219 vis_hid_.Write(os, binary);
220 vis_bias_.Write(os, binary);
221 hid_bias_.Write(os, binary);
244 <<
" data:" << hid_state.
NumCols();
258 vis_probs->
Sigmoid(*vis_probs);
276 if ( vis_hid_corr_.NumRows() != vis_hid_.NumRows() ||
277 vis_hid_corr_.NumCols() != vis_hid_.NumCols() ||
278 vis_bias_corr_.Dim() != vis_bias_.Dim() ||
279 hid_bias_corr_.Dim() != hid_bias_.Dim() ) {
280 vis_hid_corr_.Resize(vis_hid_.NumRows(), vis_hid_.NumCols(),
kSetZero);
281 vis_bias_corr_.Resize(vis_bias_.Dim(),
kSetZero);
282 hid_bias_corr_.Resize(hid_bias_.Dim(),
kSetZero);
323 if (pos_vis_std * 2 < neg_vis_std) {
325 BaseFloat scale = pos_vis_std / neg_vis_std;
326 vis_hid_.Scale(scale);
327 vis_bias_.Scale(scale);
328 hid_bias_.Scale(scale);
332 vis_hid_corr_.SetZero();
333 vis_bias_corr_.SetZero();
334 hid_bias_corr_.SetZero();
336 KALDI_WARN <<
"Mismatch between pos_vis and neg_vis variances, " 337 <<
"danger of weight explosion." 338 <<
" a) Reducing weights with scale " << scale
340 <<
" [pos_vis_std:" << pos_vis_std
341 <<
",neg_vis_std:" << neg_vis_std <<
"]";
365 vis_hid_corr_.AddMatMat(-lr/N, neg_hid,
kTrans, neg_vis,
kNoTrans, mmt);
366 vis_hid_corr_.AddMatMat(+lr/N, pos_hid,
kTrans, pos_vis,
kNoTrans, 1.0);
367 vis_hid_corr_.AddMat(-lr*l2, vis_hid_);
368 vis_hid_.AddMat(1.0, vis_hid_corr_);
375 vis_bias_corr_.AddRowSumMat(-lr/N, neg_vis, mmt);
376 vis_bias_corr_.AddRowSumMat(+lr/N, pos_vis, 1.0);
377 vis_bias_.AddVec(1.0, vis_bias_corr_, 1.0);
384 hid_bias_corr_.AddRowSumMat(-lr/N, neg_hid, mmt);
385 hid_bias_corr_.AddRowSumMat(+lr/N, pos_hid, 1.0);
386 hid_bias_.AddVec(1.0, hid_bias_corr_, 1.0);
402 if (!binary) os <<
"\n";
404 vis_hid_.Write(os, binary);
405 hid_bias_.Write(os, binary);
412 if (!binary) os <<
"\n";
433 #endif // KALDI_NNET_NNET_RBM_H_
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
void InitData(std::istream &is)
Virtual interface for initialization and I/O.
void WriteAsNnet(std::ostream &os, bool binary) const
CuVector< BaseFloat > hid_bias_
Vector with biases.
Component * Copy() const
Copy component (deep copy).
void PropagateFnc(const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out)
Abstract interface for propagation/backpropagation.
CuVector< BaseFloat > hid_bias_corr_
Vector for bias updates.
int32 input_dim_
Data members.
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
RbmTrainOptions rbm_opts_
int32 InputDim() const
Dimensionality of the network input (input feature dim.).
CuMatrix< BaseFloat > vis_hid_
Matrix with neuron weights.
RbmNodeType VisType() const
void RbmUpdate(const CuMatrixBase< BaseFloat > &pos_vis, const CuMatrixBase< BaseFloat > &pos_hid, const CuMatrixBase< BaseFloat > &neg_vis, const CuMatrixBase< BaseFloat > &neg_hid)
void RandUniform(BaseFloat mu, BaseFloat range, CuMatrixBase< Real > *mat, struct RandomState *state=NULL)
Fill CuMatrix with random numbers (Uniform distribution): mu = the mean value, range = the 'width' of...
void ReadToken(std::istream &is, bool binary, std::string *str)
ReadToken gets the next token and puts it in str (exception on failure).
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
void BackpropagateFnc(const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrixBase< BaseFloat > *in_diff)
Backward pass transformation (to be implemented by descending class...)
ComponentType
Component type identification mechanism.
static const char * TypeToMarker(ComponentType t)
Converts component type to marker.
void Reconstruct(const CuMatrixBase< BaseFloat > &hid_state, CuMatrix< BaseFloat > *vis_probs)
Adds shift to all the lines of the matrix (can be used for global mean normalization) ...
Real ComputeStdDev(const CuMatrixBase< Real > &mat)
Get the standard deviation of values in the matrix.
ComponentType GetType() const
Get the type identification of the component.
void AddVecToRows(Real alpha, const CuVectorBase< Real > &row, Real beta=1.0)
(for each row r of *this), r = alpha * row + beta * r
void Sigmoid(const CuMatrixBase< Real > &src)
Set each element to the sigmoid of the corresponding element of "src": element by element...
RbmBase(int32 dim_in, int32 dim_out)
virtual RbmNodeType VisType() const =0
void WriteData(std::ostream &os, bool binary) const
Writes the component content.
RbmNodeType HidType() const
virtual RbmNodeType HidType() const =0
int32 InputDim() const
Get the dimension of the input.
void RandGauss(BaseFloat mu, BaseFloat sigma, CuMatrixBase< Real > *mat, struct RandomState *state=NULL)
Fill CuMatrix with random numbers (Gaussian distribution): mu = the mean value, sigma = standard devi...
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename'.
void AddMatMat(Real alpha, const CuMatrixBase< Real > &A, MatrixTransposeType transA, const CuMatrixBase< Real > &B, MatrixTransposeType transB, Real beta)
C = alpha * A(^T)*B(^T) + beta * C.
const RbmTrainOptions & GetRbmTrainOptions() const
Get training hyper-parameters from the network.
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
virtual void RbmUpdate(const CuMatrixBase< BaseFloat > &pos_vis, const CuMatrixBase< BaseFloat > &pos_hid, const CuMatrixBase< BaseFloat > &neg_vis, const CuMatrixBase< BaseFloat > &neg_hid)=0
virtual void WriteAsNnet(std::ostream &os, bool binary) const =0
int32 output_dim_
Dimension of the output of the Component.
Matrix for CUDA computing.
MatrixIndexT NumCols() const
void Backpropagate(const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
A class representing a vector.
#define KALDI_ASSERT(cond)
Rbm(int32 dim_in, int32 dim_out)
CuVector< BaseFloat > vis_bias_
Vector with biases.
const Component & GetComponent(int32 c) const
Component accessor.
CuVector< BaseFloat > vis_bias_corr_
Vector for bias updates.
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Abstract class, building block of the network.
void SetRbmTrainOptions(const RbmTrainOptions &opts)
Set training hyper-parameters to the network and its UpdatableComponent(s)
CuMatrix< BaseFloat > vis_hid_corr_
Matrix for linearity updates.
int32 OutputDim() const
Get the dimension of the output.
MatrixIndexT NumRows() const
Dimensions.
virtual void Reconstruct(const CuMatrixBase< BaseFloat > &hid_state, CuMatrix< BaseFloat > *vis_probs)=0
void Resize(MatrixIndexT rows, MatrixIndexT cols, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Allocate the memory.
void CheckNanInf(const CuMatrixBase< Real > &mat, const char *msg="")
Check that matrix contains no nan or inf.
void ReadData(std::istream &is, bool binary)
Reads the component content.