Rbm Class Reference

#include <nnet-rbm.h>

Public Member Functions

 Rbm (int32 dim_in, int32 dim_out)
 
 ~Rbm ()
 
Component * Copy () const
 Copy component (deep copy).
 
ComponentType GetType () const
 Get the type identification of the component.
 
void InitData (std::istream &is)
 Initialize the internal data of the component.
 
void ReadData (std::istream &is, bool binary)
 Reads the component content.
 
void WriteData (std::ostream &os, bool binary) const
 Writes the component content.
 
void PropagateFnc (const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out)
 Forward-pass transformation.
 
void Reconstruct (const CuMatrixBase< BaseFloat > &hid_state, CuMatrix< BaseFloat > *vis_probs)
 
void RbmUpdate (const CuMatrixBase< BaseFloat > &pos_vis, const CuMatrixBase< BaseFloat > &pos_hid, const CuMatrixBase< BaseFloat > &neg_vis, const CuMatrixBase< BaseFloat > &neg_hid)
 
RbmNodeType VisType () const
 
RbmNodeType HidType () const
 
void WriteAsNnet (std::ostream &os, bool binary) const
 
- Public Member Functions inherited from RbmBase
 RbmBase (int32 dim_in, int32 dim_out)
 
void SetRbmTrainOptions (const RbmTrainOptions &opts)
 Set the training hyper-parameters of the RBM.
 
const RbmTrainOptions & GetRbmTrainOptions () const
 Get the training hyper-parameters of the RBM.
 
- Public Member Functions inherited from Component
 Component (int32 input_dim, int32 output_dim)
 Generic interface of a component.
 
virtual ~Component ()
 
virtual bool IsUpdatable () const
 Check if the component has the 'Updatable' interface (trainable components).
 
virtual bool IsMultistream () const
 Check if the component has the 'Recurrent' interface (trainable and recurrent).
 
int32 InputDim () const
 Get the dimension of the input.
 
int32 OutputDim () const
 Get the dimension of the output.
 
void Propagate (const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
 Perform forward-pass propagation 'in' -> 'out'.
 
void Backpropagate (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
 Perform backward-pass propagation 'out_diff' -> 'in_diff'.
 
void Write (std::ostream &os, bool binary) const
 Write the component to a stream.
 
virtual std::string Info () const
 Print some additional info (after <ComponentName> and the dims).
 
virtual std::string InfoGradient () const
 Print some additional info about the gradient (after <...> and dims).
 

Protected Attributes

CuMatrix< BaseFloat > vis_hid_
 Matrix with neuron weights.
 
CuVector< BaseFloat > vis_bias_
 Vector with visible-unit biases.
 
CuVector< BaseFloat > hid_bias_
 Vector with hidden-unit biases.
 
CuMatrix< BaseFloat > vis_hid_corr_
 Matrix for linearity updates.
 
CuVector< BaseFloat > vis_bias_corr_
 Vector for visible-bias updates.
 
CuVector< BaseFloat > hid_bias_corr_
 Vector for hidden-bias updates.
 
RbmNodeType vis_type_
 
RbmNodeType hid_type_
 
- Protected Attributes inherited from RbmBase
RbmTrainOptions rbm_opts_
 
- Protected Attributes inherited from Component
int32 input_dim_
 Dimension of the input of the Component.
 
int32 output_dim_
 Dimension of the output of the Component.
 

Additional Inherited Members

- Public Types inherited from RbmBase
enum  RbmNodeType { Bernoulli, Gaussian }
 
- Public Types inherited from Component
enum  ComponentType {
  kUnknown = 0x0, kUpdatableComponent = 0x0100, kAffineTransform, kLinearTransform,
  kConvolutionalComponent, kConvolutional2DComponent, kLstmProjected, kBlstmProjected,
  kRecurrentComponent, kActivationFunction = 0x0200, kSoftmax, kHiddenSoftmax,
  kBlockSoftmax, kSigmoid, kTanh, kParametricRelu,
  kDropout, kLengthNormComponent, kTranform = 0x0400, kRbm,
  kSplice, kCopy, kTranspose, kBlockLinearity,
  kAddShift, kRescale, kKlHmm = 0x0800, kSentenceAveragingComponent,
  kSimpleSentenceAveragingComponent, kAveragePoolingComponent, kAveragePooling2DComponent, kMaxPoolingComponent,
  kMaxPooling2DComponent, kFramePoolingComponent, kParallelComponent, kMultiBasisComponent
}
 Component type identification mechanism.
 
- Static Public Member Functions inherited from Component
static const char * TypeToMarker (ComponentType t)
 Converts component type to marker.
 
static ComponentType MarkerToType (const std::string &s)
 Converts marker to component type (case insensitive).
 
static Component * Init (const std::string &conf_line)
 Initialize component from a line in config file.
 
static Component * Read (std::istream &is, bool binary)
 Read the component from a stream (static method).
 
- Static Public Attributes inherited from Component
static const struct key_value kMarkerMap []
 The table with pairs of Component types and markers (defined in nnet-component.cc).
 

Detailed Description

Definition at line 96 of file nnet-rbm.h.

Constructor & Destructor Documentation

Rbm ( int32  dim_in,
int32  dim_out 
)
inline

Definition at line 98 of file nnet-rbm.h.

Referenced by Rbm::Copy().

  Rbm(int32 dim_in, int32 dim_out) :
    RbmBase(dim_in, dim_out)
  { }

~Rbm ( )
inline

Definition at line 102 of file nnet-rbm.h.

  ~Rbm()
  { }

Member Function Documentation

Component* Copy ( ) const
inline virtual

Copy component (deep copy).

Implements Component.

Definition at line 105 of file nnet-rbm.h.

References Rbm::Rbm().

  Component* Copy() const {
    return new Rbm(*this);
  }

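Copy() follows the usual polymorphic clone idiom: the copy constructor duplicates the members, and the result is returned through the base-class pointer. The sketch below illustrates the idiom with hypothetical stand-in types (`Layer`, `ToyRbm`); they are not Kaldi classes and only assume that the members have value semantics.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for the Component base class.
struct Layer {
  virtual ~Layer() {}
  virtual Layer* Copy() const = 0;  // polymorphic deep copy
  virtual float FirstWeight() const = 0;
  virtual void SetFirstWeight(float w) = 0;
};

// Hypothetical stand-in for Rbm; the weights live in a value-type container.
struct ToyRbm : public Layer {
  std::vector<float> weights;
  explicit ToyRbm(const std::vector<float> &w) : weights(w) {}

  // The implicit copy constructor copies the vector element by element,
  // so the clone owns storage that is independent of the original.
  Layer* Copy() const { return new ToyRbm(*this); }

  float FirstWeight() const {
    assert(!weights.empty());
    return weights[0];
  }
  void SetFirstWeight(float w) {
    assert(!weights.empty());
    weights[0] = w;
  }
};
```

The caller owns the returned pointer and is responsible for deleting it.
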
ComponentType GetType ( ) const
inline virtual

Get the type identification of the component.

Implements Component.

Definition at line 109 of file nnet-rbm.h.

References Component::kRbm.

  ComponentType GetType() const {
    return kRbm;
  }

RbmNodeType HidType ( ) const
inline virtual

Implements RbmBase.

Definition at line 393 of file nnet-rbm.h.

References Rbm::hid_type_.

Referenced by Rbm::WriteAsNnet().

  RbmNodeType HidType() const {
    return hid_type_;
  }

void InitData ( std::istream &  is)
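inline virtual

Virtual interface for initialization and I/O.

Initializes the internal data of the component from a configuration stream. The stream may contain the tokens <VisibleType>, <HiddenType>, <VisibleBiasMean>, <VisibleBiasRange>, <HiddenBiasMean>, <HiddenBiasRange>, <ParamStddev> and <VisibleBiasCmvnFilename>; any other token raises an error.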

Reimplemented from Component.

Definition at line 113 of file nnet-rbm.h.

References RbmBase::Bernoulli, CuVectorBase< Real >::Dim(), RbmBase::Gaussian, Nnet::GetComponent(), Rbm::hid_bias_, Rbm::hid_type_, Nnet::InputDim(), Component::InputDim(), KALDI_ASSERT, KALDI_ERR, KALDI_LOG, kaldi::Log(), Component::OutputDim(), kaldi::nnet1::RandGauss(), kaldi::nnet1::RandUniform(), Nnet::Read(), kaldi::ReadBasicType(), kaldi::ReadToken(), CuVector< Real >::Resize(), CuMatrix< Real >::Resize(), Rbm::vis_bias_, Rbm::vis_hid_, and Rbm::vis_type_.

  void InitData(std::istream &is) {
    // define options,
    std::string vis_type;
    std::string hid_type;
    float vis_bias_mean = 0.0, vis_bias_range = 0.0,
          hid_bias_mean = 0.0, hid_bias_range = 0.0,
          param_stddev = 0.1;
    std::string vis_bias_cmvn_file;  // initialize biases to logit(p_active)
    // parse config,
    std::string token;
    while (is >> std::ws, !is.eof()) {
      ReadToken(is, false, &token);
      if (token == "<VisibleType>") ReadToken(is, false, &vis_type);
      else if (token == "<HiddenType>") ReadToken(is, false, &hid_type);
      else if (token == "<VisibleBiasMean>") ReadBasicType(is, false, &vis_bias_mean);
      else if (token == "<VisibleBiasRange>") ReadBasicType(is, false, &vis_bias_range);
      else if (token == "<HiddenBiasMean>") ReadBasicType(is, false, &hid_bias_mean);
      else if (token == "<HiddenBiasRange>") ReadBasicType(is, false, &hid_bias_range);
      else if (token == "<ParamStddev>") ReadBasicType(is, false, &param_stddev);
      else if (token == "<VisibleBiasCmvnFilename>") ReadToken(is, false, &vis_bias_cmvn_file);
      else KALDI_ERR << "Unknown token " << token << " Typo in config?";
    }

    // Translate the 'node' types,
    if (vis_type == "bern" || vis_type == "Bernoulli") vis_type_ = RbmBase::Bernoulli;
    else if (vis_type == "gauss" || vis_type == "Gaussian") vis_type_ = RbmBase::Gaussian;
    else KALDI_ERR << "Wrong <VisibleType>" << vis_type;
    //
    if (hid_type == "bern" || hid_type == "Bernoulli") hid_type_ = RbmBase::Bernoulli;
    else if (hid_type == "gauss" || hid_type == "Gaussian") hid_type_ = RbmBase::Gaussian;
    else KALDI_ERR << "Wrong <HiddenType>" << hid_type;

    //
    // Initialize trainable parameters,
    //
    // visible-hidden connections,
    vis_hid_.Resize(OutputDim(), InputDim());
    RandGauss(0.0, param_stddev, &vis_hid_);
    // hidden-bias,
    hid_bias_.Resize(OutputDim());
    RandUniform(hid_bias_mean, hid_bias_range, &hid_bias_);
    // visible-bias,
    if (vis_bias_cmvn_file == "") {
      vis_bias_.Resize(InputDim());
      RandUniform(vis_bias_mean, vis_bias_range, &vis_bias_);
    } else {
      KALDI_LOG << "Initializing from <VisibleBiasCmvnFilename> "
                << vis_bias_cmvn_file;
      // Reading Nnet with 'global-cmvn' components,
      Nnet cmvn;
      cmvn.Read(vis_bias_cmvn_file);
      KALDI_ASSERT(InputDim() == cmvn.InputDim());
      // The parameters from <AddShift> correspond to 'negative' mean values,
      Vector<BaseFloat> p(cmvn.InputDim());
      dynamic_cast<AddShift&>(cmvn.GetComponent(0)).GetParams(&p);
      p.Scale(-1.0);  // 'un-do' negation of mean values,
      p.ApplyFloor(0.0001);
      p.ApplyCeiling(0.9999);
      // Getting the logit,
      Vector<BaseFloat> logit_p(p.Dim());
      for (int32 d = 0; d < p.Dim(); d++) {
        logit_p(d) = Log(p(d)) - Log(1.0 - p(d));
      }
      vis_bias_ = logit_p;
      KALDI_ASSERT(vis_bias_.Dim() == InputDim());
    }
  }
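
The <VisibleBiasCmvnFilename> branch above turns mean activations into visible biases by taking their logit. The following is a minimal standalone sketch of that arithmetic only; it uses `std::vector` instead of Kaldi's `Vector<BaseFloat>`, and the function name `LogitBiasFromShift` is an assumption made for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// 'shift' plays the role of the <AddShift> parameters, i.e. the negated
// mean activation of each visible unit.
std::vector<float> LogitBiasFromShift(const std::vector<float> &shift) {
  std::vector<float> bias(shift.size());
  for (std::size_t d = 0; d < shift.size(); ++d) {
    float p = -shift[d];                          // undo the negation
    p = std::min(std::max(p, 0.0001f), 0.9999f);  // keep the logit finite
    bias[d] = std::log(p) - std::log(1.0f - p);   // logit(p)
  }
  return bias;
}
```

With this initialization, a Bernoulli visible unit that receives no input from the hidden layer has activation probability sigmoid(logit(p)) = p, which is its mean activation in the data.
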
void PropagateFnc ( const CuMatrixBase< BaseFloat > &  in,
CuMatrixBase< BaseFloat > *  out 
)
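inline virtual

Abstract interface for propagation/backpropagation.

Forward-pass transformation: the input is multiplied by the transposed weight matrix, the hidden biases are added, and a sigmoid is applied when the hidden units are Bernoulli.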

Implements Component.

Definition at line 226 of file nnet-rbm.h.

References CuMatrixBase< Real >::AddMatMat(), CuMatrixBase< Real >::AddVecToRows(), RbmBase::Bernoulli, Rbm::hid_bias_, Rbm::hid_type_, kaldi::kNoTrans, kaldi::kTrans, CuMatrixBase< Real >::Sigmoid(), and Rbm::vis_hid_.

  void PropagateFnc(const CuMatrixBase<BaseFloat> &in,
                    CuMatrixBase<BaseFloat> *out) {
    // pre-fill with bias
    out->AddVecToRows(1.0, hid_bias_, 0.0);
    // multiply by weights^t
    out->AddMatMat(1.0, in, kNoTrans, vis_hid_, kTrans, 1.0);
    // optionally apply sigmoid
    if (hid_type_ == RbmBase::Bernoulli) {
      out->Sigmoid(*out);
    }
  }
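
The same computation can be written without the CUDA matrix classes. This is a minimal sketch, assuming dense row-major storage in nested `std::vector`s; the name `PropagateSketch` and the `Matrix` typedef are hypothetical and not part of Kaldi.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

typedef std::vector<std::vector<float> > Matrix;  // one inner vector per row

// Computes out = f(in * W^T + hid_bias), with W stored as [num_hid x num_vis].
// f is the logistic sigmoid for Bernoulli hidden units, identity otherwise.
Matrix PropagateSketch(const Matrix &in, const Matrix &W,
                       const std::vector<float> &hid_bias,
                       bool bernoulli_hidden) {
  assert(W.size() == hid_bias.size());
  Matrix out(in.size(), std::vector<float>(W.size(), 0.0f));
  for (std::size_t r = 0; r < in.size(); ++r) {
    for (std::size_t j = 0; j < W.size(); ++j) {
      assert(W[j].size() == in[r].size());
      float act = hid_bias[j];  // pre-fill with the bias
      for (std::size_t i = 0; i < in[r].size(); ++i) {
        act += in[r][i] * W[j][i];  // row r of 'in' times row j of W
      }
      out[r][j] = bernoulli_hidden ? 1.0f / (1.0f + std::exp(-act)) : act;
    }
  }
  return out;
}
```

Each row of `in` is one frame of the mini-batch; each output row holds the hidden activations for that frame.
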
void RbmUpdate ( const CuMatrixBase< BaseFloat > &  pos_vis,
const CuMatrixBase< BaseFloat > &  pos_hid,
const CuMatrixBase< BaseFloat > &  neg_vis,
const CuMatrixBase< BaseFloat > &  neg_hid 
)
inline virtual

Implements RbmBase.

Definition at line 262 of file nnet-rbm.h.

References CuMatrixBase< Real >::AddMat(), CuMatrixBase< Real >::AddMatMat(), CuVectorBase< Real >::AddRowSumMat(), CuVectorBase< Real >::AddVec(), kaldi::nnet1::CheckNanInf(), kaldi::nnet1::ComputeStdDev(), CuVectorBase< Real >::Dim(), RbmBase::Gaussian, Rbm::hid_bias_, Rbm::hid_bias_corr_, Component::input_dim_, KALDI_ASSERT, KALDI_WARN, kaldi::kNoTrans, kaldi::kSetZero, kaldi::kTrans, RbmTrainOptions::l2_penalty, RbmTrainOptions::learn_rate, RbmTrainOptions::momentum, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), Component::output_dim_, RbmBase::rbm_opts_, CuVector< Real >::Resize(), CuMatrix< Real >::Resize(), CuVectorBase< Real >::Scale(), CuMatrixBase< Real >::Scale(), CuVectorBase< Real >::SetZero(), CuMatrixBase< Real >::SetZero(), Rbm::vis_bias_, Rbm::vis_bias_corr_, Rbm::vis_hid_, Rbm::vis_hid_corr_, and Rbm::vis_type_.

  void RbmUpdate(const CuMatrixBase<BaseFloat> &pos_vis,
                 const CuMatrixBase<BaseFloat> &pos_hid,
                 const CuMatrixBase<BaseFloat> &neg_vis,
                 const CuMatrixBase<BaseFloat> &neg_hid) {
    // dims
    KALDI_ASSERT(pos_vis.NumRows() == pos_hid.NumRows() &&
                 pos_vis.NumRows() == neg_vis.NumRows() &&
                 pos_vis.NumRows() == neg_hid.NumRows() &&
                 pos_vis.NumCols() == neg_vis.NumCols() &&
                 pos_hid.NumCols() == neg_hid.NumCols() &&
                 pos_vis.NumCols() == input_dim_ &&
                 pos_hid.NumCols() == output_dim_);

    // lazy initialization of buffers
    if ( vis_hid_corr_.NumRows() != vis_hid_.NumRows() ||
         vis_hid_corr_.NumCols() != vis_hid_.NumCols() ||
         vis_bias_corr_.Dim() != vis_bias_.Dim() ||
         hid_bias_corr_.Dim() != hid_bias_.Dim() ) {
      vis_hid_corr_.Resize(vis_hid_.NumRows(), vis_hid_.NumCols(), kSetZero);
      vis_bias_corr_.Resize(vis_bias_.Dim(), kSetZero);
      hid_bias_corr_.Resize(hid_bias_.Dim(), kSetZero);
    }

    // ANTI-WEIGHT-EXPLOSION PROTECTION (Gaussian-Bernoulli RBM)
    //
    // in the following section we detect that the weights in
    // Gaussian-Bernoulli RBM are almost exploding. The weight
    // explosion is caused by large variance of the reconstructed data,
    // which causes a feed-back loop that keeps increasing the weights.
    //
    // To avoid explosion, the standard-deviation of the visible-data
    // and reconstructed-data should be about the same.
    // The model is particularly sensitive at the very
    // beginning of the CD-1 training.
    //
    // We compute the standard deviations on
    // * 'A' : input mini-batch
    // * 'B' : reconstruction.
    // When 'B > 2*A', we stabilize the training in this way:
    // 1. we scale down the weights and biases by 'A/B',
    // 2. we shrink learning rate by 0.9x,
    // 3. we reset the momentum buffer,
    //
    // A warning message is put to the log. In later stage
    // the learning-rate returns back to its original value.
    //
    // To avoid the issue, we make sure that the weight-matrix
    // is sensibly initialized.
    //
    if (vis_type_ == RbmBase::Gaussian) {
      // check the data have no nan/inf:
      CheckNanInf(pos_vis, "pos_vis");
      CheckNanInf(pos_hid, "pos_hid");
      CheckNanInf(neg_vis, "neg_vis");
      CheckNanInf(neg_hid, "neg_hid");

      // get standard deviations of pos_vis and neg_vis:
      BaseFloat pos_vis_std = ComputeStdDev(pos_vis);
      BaseFloat neg_vis_std = ComputeStdDev(neg_vis);

      // monitor the standard deviation mismatch : data vs. reconstruction
      if (pos_vis_std * 2 < neg_vis_std) {
        // 1) scale-down the weights and biases
        BaseFloat scale = pos_vis_std / neg_vis_std;
        vis_hid_.Scale(scale);
        vis_bias_.Scale(scale);
        hid_bias_.Scale(scale);
        // 2) reduce the learning rate
        rbm_opts_.learn_rate *= 0.9;
        // 3) reset the momentum buffers
        vis_hid_corr_.SetZero();
        vis_bias_corr_.SetZero();
        hid_bias_corr_.SetZero();

        KALDI_WARN << "Mismatch between pos_vis and neg_vis variances, "
                   << "danger of weight explosion."
                   << " a) Reducing weights with scale " << scale
                   << " b) Lowering learning rate to " << rbm_opts_.learn_rate
                   << " [pos_vis_std:" << pos_vis_std
                   << ",neg_vis_std:" << neg_vis_std << "]";
        return; /* i.e. don't update now, the update would be too BIG */
      }
    }
    //
    // End of weight-explosion check


    // We use these training hyper-parameters
    //
    const BaseFloat lr = rbm_opts_.learn_rate;
    const BaseFloat mmt = rbm_opts_.momentum;
    const BaseFloat l2 = rbm_opts_.l2_penalty;

    // UPDATE vishid matrix
    //
    // vishidinc = momentum*vishidinc + ...
    //   epsilonw*( (posprods-negprods)/numcases - weightcost*vishid)
    //
    // vishidinc[t] = -(epsilonw/numcases)*negprods + momentum*vishidinc[t-1]
    //                +(epsilonw/numcases)*posprods
    //                -(epsilonw*weightcost)*vishid[t-1]
    //
    BaseFloat N = static_cast<BaseFloat>(pos_vis.NumRows());
    vis_hid_corr_.AddMatMat(-lr/N, neg_hid, kTrans, neg_vis, kNoTrans, mmt);
    vis_hid_corr_.AddMatMat(+lr/N, pos_hid, kTrans, pos_vis, kNoTrans, 1.0);
    vis_hid_corr_.AddMat(-lr*l2, vis_hid_);
    vis_hid_.AddMat(1.0, vis_hid_corr_);

    // UPDATE visbias vector
    //
    // visbiasinc = momentum*visbiasinc +
    //   (epsilonvb/numcases)*(posvisact-negvisact);
    //
    vis_bias_corr_.AddRowSumMat(-lr/N, neg_vis, mmt);
    vis_bias_corr_.AddRowSumMat(+lr/N, pos_vis, 1.0);
    vis_bias_.AddVec(1.0, vis_bias_corr_, 1.0);

    // UPDATE hidbias vector
    //
    // hidbiasinc = momentum*hidbiasinc +
    //   (epsilonhb/numcases)*(poshidact-neghidact);
    //
    hid_bias_corr_.AddRowSumMat(-lr/N, neg_hid, mmt);
    hid_bias_corr_.AddRowSumMat(+lr/N, pos_hid, 1.0);
    hid_bias_.AddVec(1.0, hid_bias_corr_, 1.0);
  }
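
The two mechanisms in RbmUpdate(), the weight-explosion guard and the momentum update with L2 penalty, can be isolated in a few lines. The sketch below is an illustration on flat `std::vector`s and scalars; `Opts`, `StdDev`, `StabilizeIfExploding` and `MomentumStep` are hypothetical names, and `StdDev` assumes the deviation is taken over all elements of the mini-batch.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical option bundle mirroring learn_rate / momentum / l2_penalty.
struct Opts { float learn_rate, momentum, l2_penalty; };

// Standard deviation over all elements of a flattened mini-batch.
float StdDev(const std::vector<float> &x) {
  assert(!x.empty());
  double sum = 0.0, sum_sq = 0.0;
  for (std::size_t i = 0; i < x.size(); ++i) {
    sum += x[i];
    sum_sq += static_cast<double>(x[i]) * x[i];
  }
  double mean = sum / x.size();
  double var = sum_sq / x.size() - mean * mean;
  return static_cast<float>(std::sqrt(var > 0.0 ? var : 0.0));
}

// Returns true when the reconstruction spread exceeds twice the data spread.
// In that case the weights are scaled by A/B, the learning rate is multiplied
// by 0.9 and the momentum buffer is cleared; the caller then skips the update.
bool StabilizeIfExploding(const std::vector<float> &pos_vis,
                          const std::vector<float> &neg_vis,
                          std::vector<float> *weights,
                          std::vector<float> *corr, Opts *opts) {
  float a = StdDev(pos_vis), b = StdDev(neg_vis);
  if (a * 2 >= b) return false;
  float scale = a / b;
  for (std::size_t i = 0; i < weights->size(); ++i) (*weights)[i] *= scale;
  opts->learn_rate *= 0.9f;
  for (std::size_t i = 0; i < corr->size(); ++i) (*corr)[i] = 0.0f;
  return true;
}

// One momentum step for a single weight w, given its summed positive and
// negative statistics over n frames:
//   corr = mmt * corr + (lr / n) * (pos - neg) - lr * l2 * w;  w += corr
void MomentumStep(float pos_stat, float neg_stat, float n, const Opts &o,
                  float *corr, float *w) {
  *corr = o.momentum * (*corr) + (o.learn_rate / n) * (pos_stat - neg_stat)
          - o.learn_rate * o.l2_penalty * (*w);
  *w += *corr;
}
```

In the class itself the bias updates follow the same momentum form, but without the L2 term.
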
void ReadData ( std::istream &  is,
bool  binary 
)
inline virtual

Reads the component content.

Reimplemented from Component.

Definition at line 182 of file nnet-rbm.h.

References RbmBase::Bernoulli, CuVectorBase< Real >::Dim(), RbmBase::Gaussian, Rbm::hid_bias_, Rbm::hid_type_, Component::input_dim_, KALDI_ASSERT, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), Component::output_dim_, CuVector< Real >::Read(), CuMatrix< Real >::Read(), kaldi::ReadToken(), Rbm::vis_bias_, Rbm::vis_hid_, and Rbm::vis_type_.

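  void ReadData(std::istream &is, bool binary) {
    std::string vis_node_type, hid_node_type;
    ReadToken(is, binary, &vis_node_type);
    ReadToken(is, binary, &hid_node_type);

    if (vis_node_type == "bern") {
      vis_type_ = RbmBase::Bernoulli;
    } else if (vis_node_type == "gauss") {
      vis_type_ = RbmBase::Gaussian;
    }
    if (hid_node_type == "bern") {
      hid_type_ = RbmBase::Bernoulli;
    } else if (hid_node_type == "gauss") {
      hid_type_ = RbmBase::Gaussian;
    }

    vis_hid_.Read(is, binary);
    vis_bias_.Read(is, binary);
    hid_bias_.Read(is, binary);

    KALDI_ASSERT(vis_hid_.NumRows() == output_dim_);
    KALDI_ASSERT(vis_hid_.NumCols() == input_dim_);
    KALDI_ASSERT(vis_bias_.Dim() == input_dim_);
    KALDI_ASSERT(hid_bias_.Dim() == output_dim_);
  }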
void Reconstruct ( const CuMatrixBase< BaseFloat > &  hid_state,
CuMatrix< BaseFloat > *  vis_probs 
)
inline virtual

Implements RbmBase.

Definition at line 239 of file nnet-rbm.h.

References CuMatrixBase< Real >::AddMatMat(), CuMatrixBase< Real >::AddVecToRows(), RbmBase::Bernoulli, Component::input_dim_, KALDI_ERR, kaldi::kNoTrans, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), Component::output_dim_, CuMatrix< Real >::Resize(), CuMatrixBase< Real >::Sigmoid(), Rbm::vis_bias_, Rbm::vis_hid_, and Rbm::vis_type_.

  void Reconstruct(const CuMatrixBase<BaseFloat> &hid_state,
                   CuMatrix<BaseFloat> *vis_probs) {
    // check the dim
    if (output_dim_ != hid_state.NumCols()) {
      KALDI_ERR << "Nonmatching dims, component:" << output_dim_
                << " data:" << hid_state.NumCols();
    }
    // optionally allocate buffer
    if (input_dim_ != vis_probs->NumCols() ||
        hid_state.NumRows() != vis_probs->NumRows()) {
      vis_probs->Resize(hid_state.NumRows(), input_dim_);
    }

    // pre-fill with bias
    vis_probs->AddVecToRows(1.0, vis_bias_, 0.0);
    // multiply by weights
    vis_probs->AddMatMat(1.0, hid_state, kNoTrans, vis_hid_, kNoTrans, 1.0);
    // optionally apply sigmoid
    if (vis_type_ == RbmBase::Bernoulli) {
      vis_probs->Sigmoid(*vis_probs);
    }
  }
RbmNodeType VisType ( ) const
inline virtual

Implements RbmBase.

Definition at line 389 of file nnet-rbm.h.

References Rbm::vis_type_.

  RbmNodeType VisType() const {
    return vis_type_;
  }

void WriteAsNnet ( std::ostream &  os,
bool  binary 
) const
inline virtual

Implements RbmBase.

Definition at line 397 of file nnet-rbm.h.

References RbmBase::Bernoulli, Rbm::hid_bias_, Rbm::HidType(), Component::InputDim(), Component::kAffineTransform, Component::kSigmoid, Component::OutputDim(), Component::TypeToMarker(), Rbm::vis_hid_, CuVector< Real >::Write(), CuMatrixBase< Real >::Write(), kaldi::WriteBasicType(), and kaldi::WriteToken().

  void WriteAsNnet(std::ostream& os, bool binary) const {
    // header,
    WriteToken(os, binary, Component::TypeToMarker(Component::kAffineTransform));
    WriteBasicType(os, binary, OutputDim());
    WriteBasicType(os, binary, InputDim());
    if (!binary) os << "\n";
    // data,
    vis_hid_.Write(os, binary);
    hid_bias_.Write(os, binary);
    // sigmoid activation,
    if (HidType() == Bernoulli) {
      WriteToken(os, binary, Component::TypeToMarker(Component::kSigmoid));
      WriteBasicType(os, binary, OutputDim());
      WriteBasicType(os, binary, OutputDim());
    }
    if (!binary) os << "\n";
  }
void WriteData ( std::ostream &  os,
bool  binary 
) const
inline virtual

Writes the component content.

Reimplemented from Component.

Definition at line 208 of file nnet-rbm.h.

References RbmBase::Bernoulli, RbmBase::Gaussian, Rbm::hid_bias_, Rbm::hid_type_, KALDI_ERR, Rbm::vis_bias_, Rbm::vis_hid_, Rbm::vis_type_, CuVector< Real >::Write(), CuMatrixBase< Real >::Write(), and kaldi::WriteToken().

  void WriteData(std::ostream &os, bool binary) const {
    switch (vis_type_) {
      case Bernoulli : WriteToken(os, binary, "bern"); break;
      case Gaussian : WriteToken(os, binary, "gauss"); break;
      default : KALDI_ERR << "Unknown type " << vis_type_;
    }
    switch (hid_type_) {
      case Bernoulli : WriteToken(os, binary, "bern"); break;
      case Gaussian : WriteToken(os, binary, "gauss"); break;
      default : KALDI_ERR << "Unknown type " << hid_type_;
    }
    vis_hid_.Write(os, binary);
    vis_bias_.Write(os, binary);
    hid_bias_.Write(os, binary);
  }
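
WriteData() and ReadData() agree on a small header: the visible and hidden node types are stored as the tokens "bern" or "gauss", followed by the weight matrix and the two bias vectors. The sketch below round-trips only the two type tokens in text mode, using standard streams rather than Kaldi's WriteToken/ReadToken; all names are hypothetical. Unlike ReadData(), which leaves the type unchanged on an unrecognized token, this sketch throws.

```cpp
#include <cassert>
#include <sstream>
#include <stdexcept>
#include <string>

enum NodeType { Bernoulli, Gaussian };  // mirrors RbmBase::RbmNodeType

const char *NodeTypeToToken(NodeType t) {
  switch (t) {
    case Bernoulli: return "bern";
    case Gaussian: return "gauss";
  }
  throw std::invalid_argument("Unknown node type");
}

NodeType TokenToNodeType(const std::string &token) {
  if (token == "bern") return Bernoulli;
  if (token == "gauss") return Gaussian;
  throw std::invalid_argument("Unknown node type token: " + token);
}

// Text mode only: writes the two types as whitespace-separated tokens.
void WriteTypes(std::ostream &os, NodeType vis, NodeType hid) {
  os << NodeTypeToToken(vis) << ' ' << NodeTypeToToken(hid) << ' ';
}

// Reads the two tokens back in the same order: visible first, then hidden.
void ReadTypes(std::istream &is, NodeType *vis, NodeType *hid) {
  std::string v, h;
  if (!(is >> v >> h)) throw std::runtime_error("Truncated header");
  *vis = TokenToNodeType(v);
  *hid = TokenToNodeType(h);
}
```

Because the order of the fields is implicit, the reader has to consume them in exactly the order the writer produced them.
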

Member Data Documentation

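CuVector<BaseFloat> hid_bias_
protected

Vector with biases.

Definition at line 418 of file nnet-rbm.h.

Referenced by Rbm::InitData(), Rbm::PropagateFnc(), Rbm::RbmUpdate(), Rbm::ReadData(), Rbm::WriteAsNnet(), and Rbm::WriteData().
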
CuVector<BaseFloat> hid_bias_corr_
protected

Vector for bias updates.

Definition at line 422 of file nnet-rbm.h.

Referenced by Rbm::RbmUpdate().

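RbmNodeType hid_type_
protected

Definition at line 425 of file nnet-rbm.h.

Referenced by Rbm::HidType(), Rbm::InitData(), Rbm::PropagateFnc(), Rbm::ReadData(), and Rbm::WriteData().
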
CuVector<BaseFloat> vis_bias_
protected

Vector with biases.

Definition at line 417 of file nnet-rbm.h.

Referenced by Rbm::InitData(), Rbm::RbmUpdate(), Rbm::ReadData(), Rbm::Reconstruct(), and Rbm::WriteData().

CuVector<BaseFloat> vis_bias_corr_
protected

Vector for bias updates.

Definition at line 421 of file nnet-rbm.h.

Referenced by Rbm::RbmUpdate().

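CuMatrix<BaseFloat> vis_hid_
protected

Matrix with neuron weights.

Definition at line 416 of file nnet-rbm.h.

Referenced by Rbm::InitData(), Rbm::PropagateFnc(), Rbm::RbmUpdate(), Rbm::ReadData(), Rbm::Reconstruct(), Rbm::WriteAsNnet(), and Rbm::WriteData().
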
CuMatrix<BaseFloat> vis_hid_corr_
protected

Matrix for linearity updates.

Definition at line 420 of file nnet-rbm.h.

Referenced by Rbm::RbmUpdate().

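RbmNodeType vis_type_
protected

Definition at line 424 of file nnet-rbm.h.

Referenced by Rbm::InitData(), Rbm::RbmUpdate(), Rbm::ReadData(), Rbm::Reconstruct(), Rbm::VisType(), and Rbm::WriteData().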

The documentation for this class was generated from the following file: