nnet-rbm.h
Go to the documentation of this file.
1 // nnet/nnet-rbm.h
2 
3 // Copyright 2012-2013 Brno University of Technology (Author: Karel Vesely)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
21 #ifndef KALDI_NNET_NNET_RBM_H_
22 #define KALDI_NNET_NNET_RBM_H_
23 
24 #include <string>
25 
26 #include "nnet/nnet-component.h"
27 #include "nnet/nnet-nnet.h"
28 #include "nnet/nnet-utils.h"
29 #include "nnet/nnet-various.h"
30 #include "cudamatrix/cu-math.h"
31 
32 namespace kaldi {
33 namespace nnet1 {
34 
35 class RbmBase : public Component {
36  public:
37  typedef enum {
40  } RbmNodeType;
41 
42  RbmBase(int32 dim_in, int32 dim_out):
43  Component(dim_in, dim_out)
44  { }
45 
46  // Inherited from Component::
47  // void Propagate(...)
48  // virtual void PropagateFnc(...) = 0
49 
50  virtual void Reconstruct(
51  const CuMatrixBase<BaseFloat> &hid_state,
52  CuMatrix<BaseFloat> *vis_probs
53  ) = 0;
54  virtual void RbmUpdate(
55  const CuMatrixBase<BaseFloat> &pos_vis,
56  const CuMatrixBase<BaseFloat> &pos_hid,
57  const CuMatrixBase<BaseFloat> &neg_vis,
58  const CuMatrixBase<BaseFloat> &neg_hid
59  ) = 0;
60 
61  virtual RbmNodeType VisType() const = 0;
62  virtual RbmNodeType HidType() const = 0;
63 
64  virtual void WriteAsNnet(std::ostream& os, bool binary) const = 0;
65 
67  void SetRbmTrainOptions(const RbmTrainOptions& opts) {
68  rbm_opts_ = opts;
69  }
72  return rbm_opts_;
73  }
74 
75  protected:
77 
78  private:
80  // as for RBMs we use Reconstruct(.)
82  const CuMatrixBase<BaseFloat> &out,
83  const CuMatrixBase<BaseFloat> &out_diff,
84  CuMatrix<BaseFloat> *in_diff)
85  { }
87  const CuMatrixBase<BaseFloat> &out,
88  const CuMatrixBase<BaseFloat> &out_diff,
89  CuMatrixBase<BaseFloat> *in_diff)
90  { }
92 };
93 
94 
95 
96 class Rbm : public RbmBase {
97  public:
98  Rbm(int32 dim_in, int32 dim_out):
99  RbmBase(dim_in, dim_out)
100  { }
101 
103  { }
104 
105  Component* Copy() const {
106  return new Rbm(*this);
107  }
108 
110  return kRbm;
111  }
112 
113  void InitData(std::istream &is) {
114  // define options,
115  std::string vis_type;
116  std::string hid_type;
117  float vis_bias_mean = 0.0, vis_bias_range = 0.0,
118  hid_bias_mean = 0.0, hid_bias_range = 0.0,
119  param_stddev = 0.1;
120  std::string vis_bias_cmvn_file; // initialize biases to logit(p_active)
121  // parse config,
122  std::string token;
123  while (is >> std::ws, !is.eof()) {
124  ReadToken(is, false, &token);
125  if (token == "<VisibleType>") ReadToken(is, false, &vis_type);
126  else if (token == "<HiddenType>") ReadToken(is, false, &hid_type);
127  else if (token == "<VisibleBiasMean>") ReadBasicType(is, false, &vis_bias_mean);
128  else if (token == "<VisibleBiasRange>") ReadBasicType(is, false, &vis_bias_range);
129  else if (token == "<HiddenBiasMean>") ReadBasicType(is, false, &hid_bias_mean);
130  else if (token == "<HiddenBiasRange>") ReadBasicType(is, false, &hid_bias_range);
131  else if (token == "<ParamStddev>") ReadBasicType(is, false, &param_stddev);
132  else if (token == "<VisibleBiasCmvnFilename>") ReadToken(is, false, &vis_bias_cmvn_file);
133  else KALDI_ERR << "Unknown token " << token << " Typo in config?";
134  }
135 
136  // Translate the 'node' types,
137  if (vis_type == "bern" || vis_type == "Bernoulli") vis_type_ = RbmBase::Bernoulli;
138  else if (vis_type == "gauss" || vis_type == "Gaussian") vis_type_ = RbmBase::Gaussian;
139  else KALDI_ERR << "Wrong <VisibleType>" << vis_type;
140  //
141  if (hid_type == "bern" || hid_type == "Bernoulli") hid_type_ = RbmBase::Bernoulli;
142  else if (hid_type == "gauss" || hid_type == "Gaussian") hid_type_ = RbmBase::Gaussian;
143  else KALDI_ERR << "Wrong <HiddenType>" << hid_type;
144 
145  //
146  // Initialize trainable parameters,
147  //
148  // visible-hidden connections,
149  vis_hid_.Resize(OutputDim(), InputDim());
150  RandGauss(0.0, param_stddev, &vis_hid_);
151  // hidden-bias,
152  hid_bias_.Resize(OutputDim());
153  RandUniform(hid_bias_mean, hid_bias_range, &hid_bias_);
154  // visible-bias,
155  if (vis_bias_cmvn_file == "") {
156  vis_bias_.Resize(InputDim());
157  RandUniform(vis_bias_mean, vis_bias_range, &vis_bias_);
158  } else {
159  KALDI_LOG << "Initializing from <VisibleBiasCmvnFilename> "
160  << vis_bias_cmvn_file;
161  // Reading Nnet with 'global-cmvn' components,
162  Nnet cmvn;
163  cmvn.Read(vis_bias_cmvn_file);
164  KALDI_ASSERT(InputDim() == cmvn.InputDim());
165  // The parameters from <AddShift> correspond to 'negative' mean values,
166  Vector<BaseFloat> p(cmvn.InputDim());
167  dynamic_cast<AddShift&>(cmvn.GetComponent(0)).GetParams(&p);
168  p.Scale(-1.0); // 'un-do' negation of mean values,
169  p.ApplyFloor(0.0001);
170  p.ApplyCeiling(0.9999);
171  // Getting the logit,
172  Vector<BaseFloat> logit_p(p.Dim());
173  for (int32 d = 0; d < p.Dim(); d++) {
174  logit_p(d) = Log(p(d)) - Log(1.0 - p(d));
175  }
176  vis_bias_ = logit_p;
177  KALDI_ASSERT(vis_bias_.Dim() == InputDim());
178  }
179  }
180 
181 
182  void ReadData(std::istream &is, bool binary) {
183  std::string vis_node_type, hid_node_type;
184  ReadToken(is, binary, &vis_node_type);
185  ReadToken(is, binary, &hid_node_type);
186 
187  if (vis_node_type == "bern") {
188  vis_type_ = RbmBase::Bernoulli;
189  } else if (vis_node_type == "gauss") {
190  vis_type_ = RbmBase::Gaussian;
191  }
192  if (hid_node_type == "bern") {
193  hid_type_ = RbmBase::Bernoulli;
194  } else if (hid_node_type == "gauss") {
195  hid_type_ = RbmBase::Gaussian;
196  }
197 
198  vis_hid_.Read(is, binary);
199  vis_bias_.Read(is, binary);
200  hid_bias_.Read(is, binary);
201 
202  KALDI_ASSERT(vis_hid_.NumRows() == output_dim_);
203  KALDI_ASSERT(vis_hid_.NumCols() == input_dim_);
204  KALDI_ASSERT(vis_bias_.Dim() == input_dim_);
205  KALDI_ASSERT(hid_bias_.Dim() == output_dim_);
206  }
207 
208  void WriteData(std::ostream &os, bool binary) const {
209  switch (vis_type_) {
210  case Bernoulli : WriteToken(os,binary, "bern"); break;
211  case Gaussian : WriteToken(os,binary, "gauss"); break;
212  default : KALDI_ERR << "Unknown type " << vis_type_;
213  }
214  switch (hid_type_) {
215  case Bernoulli : WriteToken(os,binary, "bern"); break;
216  case Gaussian : WriteToken(os,binary, "gauss"); break;
217  default : KALDI_ERR << "Unknown type " << hid_type_;
218  }
219  vis_hid_.Write(os, binary);
220  vis_bias_.Write(os, binary);
221  hid_bias_.Write(os, binary);
222  }
223 
224 
225  // Component API
228  // pre-fill with bias
229  out->AddVecToRows(1.0, hid_bias_, 0.0);
230  // multiply by weights^t
231  out->AddMatMat(1.0, in, kNoTrans, vis_hid_, kTrans, 1.0);
232  // optionally apply sigmoid
233  if (hid_type_ == RbmBase::Bernoulli) {
234  out->Sigmoid(*out);
235  }
236  }
237 
238  // RBM training API
239  void Reconstruct(const CuMatrixBase<BaseFloat> &hid_state,
240  CuMatrix<BaseFloat> *vis_probs) {
241  // check the dim
242  if (output_dim_ != hid_state.NumCols()) {
243  KALDI_ERR << "Nonmatching dims, component:" << output_dim_
244  << " data:" << hid_state.NumCols();
245  }
246  // optionally allocate buffer
247  if (input_dim_ != vis_probs->NumCols() ||
248  hid_state.NumRows() != vis_probs->NumRows()) {
249  vis_probs->Resize(hid_state.NumRows(), input_dim_);
250  }
251 
252  // pre-fill with bias
253  vis_probs->AddVecToRows(1.0, vis_bias_, 0.0);
254  // multiply by weights
255  vis_probs->AddMatMat(1.0, hid_state, kNoTrans, vis_hid_, kNoTrans, 1.0);
256  // optionally apply sigmoid
257  if (vis_type_ == RbmBase::Bernoulli) {
258  vis_probs->Sigmoid(*vis_probs);
259  }
260  }
261 
262  void RbmUpdate(const CuMatrixBase<BaseFloat> &pos_vis,
263  const CuMatrixBase<BaseFloat> &pos_hid,
264  const CuMatrixBase<BaseFloat> &neg_vis,
265  const CuMatrixBase<BaseFloat> &neg_hid) {
266  // dims
267  KALDI_ASSERT(pos_vis.NumRows() == pos_hid.NumRows() &&
268  pos_vis.NumRows() == neg_vis.NumRows() &&
269  pos_vis.NumRows() == neg_hid.NumRows() &&
270  pos_vis.NumCols() == neg_vis.NumCols() &&
271  pos_hid.NumCols() == neg_hid.NumCols() &&
272  pos_vis.NumCols() == input_dim_ &&
273  pos_hid.NumCols() == output_dim_);
274 
275  // lazy initialization of buffers
276  if ( vis_hid_corr_.NumRows() != vis_hid_.NumRows() ||
277  vis_hid_corr_.NumCols() != vis_hid_.NumCols() ||
278  vis_bias_corr_.Dim() != vis_bias_.Dim() ||
279  hid_bias_corr_.Dim() != hid_bias_.Dim() ) {
280  vis_hid_corr_.Resize(vis_hid_.NumRows(), vis_hid_.NumCols(), kSetZero);
281  vis_bias_corr_.Resize(vis_bias_.Dim(), kSetZero);
282  hid_bias_corr_.Resize(hid_bias_.Dim(), kSetZero);
283  }
284 
285  // ANTI-WEIGHT-EXPLOSION PROTECTION (Gaussian-Bernoulli RBM)
286  //
287  // in the following section we detect that the weights in
288  // Gaussian-Bernoulli RBM are almost exploding. The weight
289  // explosion is caused by large variance of the reconstructed data,
290  // which causes a feed-back loop that keeps increasing the weights.
291  //
292  // To avoid explosion, the standard-deviation of the visible-data
293  // and reconstructed-data should be about the same.
294  // The model is particularly sensitive at the very
295  // beginning of the CD-1 training.
296  //
297  // We compute the standard deviations on
298  // * 'A' : input mini-batch
299  // * 'B' : reconstruction.
300  // When 'B > 2*A', we stabilize the training in this way:
301  // 1. we scale down the weights and biases by 'A/B',
302  // 2. we shrink learning rate by 0.9x,
303  // 3. we reset the momentum buffer,
304  //
305  // A warning message is put to the log. In later stage
306  // the learning-rate returns back to its original value.
307  //
308  // To avoid the issue, we make sure that the weight-matrix
309  // is sensibly initialized.
310  //
311  if (vis_type_ == RbmBase::Gaussian) {
312  // check the data have no nan/inf:
313  CheckNanInf(pos_vis, "pos_vis");
314  CheckNanInf(pos_hid, "pos_hid");
315  CheckNanInf(neg_vis, "neg_vis");
316  CheckNanInf(neg_hid, "pos_hid");
317 
318  // get standard deviations of pos_vis and neg_vis:
319  BaseFloat pos_vis_std = ComputeStdDev(pos_vis);
320  BaseFloat neg_vis_std = ComputeStdDev(neg_vis);
321 
322  // monitor the standard deviation mismatch : data vs. reconstruction
323  if (pos_vis_std * 2 < neg_vis_std) {
324  // 1) scale-down the weights and biases
325  BaseFloat scale = pos_vis_std / neg_vis_std;
326  vis_hid_.Scale(scale);
327  vis_bias_.Scale(scale);
328  hid_bias_.Scale(scale);
329  // 2) reduce the learning rate
330  rbm_opts_.learn_rate *= 0.9;
331  // 3) reset the momentum buffers
332  vis_hid_corr_.SetZero();
333  vis_bias_corr_.SetZero();
334  hid_bias_corr_.SetZero();
335 
336  KALDI_WARN << "Mismatch between pos_vis and neg_vis variances, "
337  << "danger of weight explosion."
338  << " a) Reducing weights with scale " << scale
339  << " b) Lowering learning rate to " << rbm_opts_.learn_rate
340  << " [pos_vis_std:" << pos_vis_std
341  << ",neg_vis_std:" << neg_vis_std << "]";
342  return; /* i.e. don't update now, the update would be too BIG */
343  }
344  }
345  //
346  // End of weight-explosion check
347 
348 
349  // We use these training hyper-parameters
350  //
351  const BaseFloat lr = rbm_opts_.learn_rate;
352  const BaseFloat mmt = rbm_opts_.momentum;
353  const BaseFloat l2 = rbm_opts_.l2_penalty;
354 
355  // UPDATE vishid matrix
356  //
357  // vishidinc = momentum*vishidinc + ...
358  // epsilonw*( (posprods-negprods)/numcases - weightcost*vishid)
359  //
360  // vishidinc[t] = -(epsilonw/numcases)*negprods + momentum*vishidinc[t-1]
361  // +(epsilonw/numcases)*posprods
362  // -(epsilonw*weightcost)*vishid[t-1]
363  //
364  BaseFloat N = static_cast<BaseFloat>(pos_vis.NumRows());
365  vis_hid_corr_.AddMatMat(-lr/N, neg_hid, kTrans, neg_vis, kNoTrans, mmt);
366  vis_hid_corr_.AddMatMat(+lr/N, pos_hid, kTrans, pos_vis, kNoTrans, 1.0);
367  vis_hid_corr_.AddMat(-lr*l2, vis_hid_);
368  vis_hid_.AddMat(1.0, vis_hid_corr_);
369 
370  // UPDATE visbias vector
371  //
372  // visbiasinc = momentum*visbiasinc +
373  // (epsilonvb/numcases)*(posvisact-negvisact);
374  //
375  vis_bias_corr_.AddRowSumMat(-lr/N, neg_vis, mmt);
376  vis_bias_corr_.AddRowSumMat(+lr/N, pos_vis, 1.0);
377  vis_bias_.AddVec(1.0, vis_bias_corr_, 1.0);
378 
379  // UPDATE hidbias vector
380  //
381  // hidbiasinc = momentum*hidbiasinc +
382  // (epsilonhb/numcases)*(poshidact-neghidact);
383  //
384  hid_bias_corr_.AddRowSumMat(-lr/N, neg_hid, mmt);
385  hid_bias_corr_.AddRowSumMat(+lr/N, pos_hid, 1.0);
386  hid_bias_.AddVec(1.0, hid_bias_corr_, 1.0);
387  }
388 
390  return vis_type_;
391  }
392 
394  return hid_type_;
395  }
396 
397  void WriteAsNnet(std::ostream& os, bool binary) const {
398  // header,
400  WriteBasicType(os, binary, OutputDim());
401  WriteBasicType(os, binary, InputDim());
402  if (!binary) os << "\n";
403  // data,
404  vis_hid_.Write(os, binary);
405  hid_bias_.Write(os, binary);
406  // sigmoid activation,
407  if (HidType() == Bernoulli) {
409  WriteBasicType(os, binary, OutputDim());
410  WriteBasicType(os, binary, OutputDim());
411  }
412  if (!binary) os << "\n";
413  }
414 
415  protected:
419 
423 
426 };
427 
428 
429 
430 } // namespace nnet1
431 } // namespace kaldi
432 
433 #endif // KALDI_NNET_NNET_RBM_H_
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void InitData(std::istream &is)
Virtual interface for initialization and I/O,.
Definition: nnet-rbm.h:113
void WriteAsNnet(std::ostream &os, bool binary) const
Definition: nnet-rbm.h:397
CuVector< BaseFloat > hid_bias_
Vector with biases.
Definition: nnet-rbm.h:418
Component * Copy() const
Copy component (deep copy),.
Definition: nnet-rbm.h:105
void PropagateFnc(const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out)
Abstract interface for propagation/backpropagation.
Definition: nnet-rbm.h:226
CuVector< BaseFloat > hid_bias_corr_
Vector for bias updates.
Definition: nnet-rbm.h:422
int32 input_dim_
Data members,.
RbmNodeType hid_type_
Definition: nnet-rbm.h:425
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
RbmTrainOptions rbm_opts_
Definition: nnet-rbm.h:76
RbmNodeType vis_type_
Definition: nnet-rbm.h:424
int32 InputDim() const
Dimensionality on network input (input feature dim.),.
Definition: nnet-nnet.cc:148
CuMatrix< BaseFloat > vis_hid_
Matrix with neuron weights.
Definition: nnet-rbm.h:416
RbmNodeType VisType() const
Definition: nnet-rbm.h:389
void RbmUpdate(const CuMatrixBase< BaseFloat > &pos_vis, const CuMatrixBase< BaseFloat > &pos_hid, const CuMatrixBase< BaseFloat > &neg_vis, const CuMatrixBase< BaseFloat > &neg_hid)
Definition: nnet-rbm.h:262
void RandUniform(BaseFloat mu, BaseFloat range, CuMatrixBase< Real > *mat, struct RandomState *state=NULL)
Fill CuMatrix with random numbers (Uniform distribution): mu = the mean value, range = the 'width' of...
Definition: nnet-utils.h:188
kaldi::int32 int32
void ReadToken(std::istream &is, bool binary, std::string *str)
ReadToken gets the next token and puts it in str (exception on failure).
Definition: io-funcs.cc:154
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
Definition: matrix-common.h:71
void BackpropagateFnc(const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrixBase< BaseFloat > *in_diff)
Backward pass transformation (to be implemented by descending class...)
Definition: nnet-rbm.h:86
ComponentType
Component type identification mechanism,.
static const char * TypeToMarker(ComponentType t)
Converts component type to marker,.
void Reconstruct(const CuMatrixBase< BaseFloat > &hid_state, CuMatrix< BaseFloat > *vis_probs)
Definition: nnet-rbm.h:239
Adds shift to all the lines of the matrix (can be used for global mean normalization) ...
Definition: nnet-various.h:291
Real ComputeStdDev(const CuMatrixBase< Real > &mat)
Get the standard deviation of values in the matrix.
Definition: nnet-utils.h:142
ComponentType GetType() const
Get Type Identification of the component,.
Definition: nnet-rbm.h:109
void AddVecToRows(Real alpha, const CuVectorBase< Real > &row, Real beta=1.0)
(for each row r of *this), r = alpha * row + beta * r
Definition: cu-matrix.cc:1261
double Log(double x)
Definition: kaldi-math.h:100
void Sigmoid(const CuMatrixBase< Real > &src)
Set each element to the sigmoid of the corresponding element of "src": element by element...
Definition: cu-matrix.cc:1534
RbmBase(int32 dim_in, int32 dim_out)
Definition: nnet-rbm.h:42
virtual RbmNodeType VisType() const =0
void WriteData(std::ostream &os, bool binary) const
Writes the component content.
Definition: nnet-rbm.h:208
RbmNodeType HidType() const
Definition: nnet-rbm.h:393
virtual RbmNodeType HidType() const =0
int32 InputDim() const
Get the dimension of the input,.
#define KALDI_ERR
Definition: kaldi-error.h:147
void RandGauss(BaseFloat mu, BaseFloat sigma, CuMatrixBase< Real > *mat, struct RandomState *state=NULL)
Fill CuMatrix with random numbers (Gaussian distribution): mu = the mean value, sigma = standard devi...
Definition: nnet-utils.h:164
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename',
Definition: nnet-nnet.cc:333
void AddMatMat(Real alpha, const CuMatrixBase< Real > &A, MatrixTransposeType transA, const CuMatrixBase< Real > &B, MatrixTransposeType transB, Real beta)
C = alpha * A(^T)*B(^T) + beta * C.
Definition: cu-matrix.cc:1291
#define KALDI_WARN
Definition: kaldi-error.h:150
const RbmTrainOptions & GetRbmTrainOptions() const
Get training hyper-parameters from the network.
Definition: nnet-rbm.h:71
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
virtual void RbmUpdate(const CuMatrixBase< BaseFloat > &pos_vis, const CuMatrixBase< BaseFloat > &pos_hid, const CuMatrixBase< BaseFloat > &neg_vis, const CuMatrixBase< BaseFloat > &neg_hid)=0
virtual void WriteAsNnet(std::ostream &os, bool binary) const =0
int32 output_dim_
Dimension of the output of the Component,.
Matrix for CUDA computing.
Definition: matrix-common.h:69
MatrixIndexT NumCols() const
Definition: cu-matrix.h:216
void Backpropagate(const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
Definition: nnet-rbm.h:81
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
Rbm(int32 dim_in, int32 dim_out)
Definition: nnet-rbm.h:98
CuVector< BaseFloat > vis_bias_
Vector with biases.
Definition: nnet-rbm.h:417
const Component & GetComponent(int32 c) const
Component accessor,.
Definition: nnet-nnet.cc:153
CuVector< BaseFloat > vis_bias_corr_
Vector for bias updates.
Definition: nnet-rbm.h:421
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34
Abstract class, building block of the network.
void SetRbmTrainOptions(const RbmTrainOptions &opts)
Set training hyper-parameters to the network and its UpdatableComponent(s)
Definition: nnet-rbm.h:67
CuMatrix< BaseFloat > vis_hid_corr_
Matrix for linearity updates.
Definition: nnet-rbm.h:420
int32 OutputDim() const
Get the dimension of the output,.
MatrixIndexT NumRows() const
Dimensions.
Definition: cu-matrix.h:215
#define KALDI_LOG
Definition: kaldi-error.h:153
virtual void Reconstruct(const CuMatrixBase< BaseFloat > &hid_state, CuMatrix< BaseFloat > *vis_probs)=0
void Resize(MatrixIndexT rows, MatrixIndexT cols, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Allocate the memory.
Definition: cu-matrix.cc:50
void CheckNanInf(const CuMatrixBase< Real > &mat, const char *msg="")
Check that matrix contains no nan or inf.
Definition: nnet-utils.h:132
void ReadData(std::istream &is, bool binary)
Reads the component content.
Definition: nnet-rbm.h:182