NonlinearComponent Class Reference

This kind of Component is a base-class for things like sigmoid and softmax.

#include <nnet-component.h>


Public Member Functions

void Init (int32 dim)
 
 NonlinearComponent (int32 dim)
 
 NonlinearComponent ()
 
 NonlinearComponent (const NonlinearComponent &other)
 
virtual int32 InputDim () const
 Get size of input vectors.

virtual int32 OutputDim () const
 Get size of output vectors.

virtual void InitFromString (std::string args)
 We implement InitFromString at this level.

virtual void Read (std::istream &is, bool binary)
 We implement Read at this level as it just needs the Type().

virtual void Write (std::ostream &os, bool binary) const
 Write component to stream.
 
void Scale (BaseFloat scale)
 
void Add (BaseFloat alpha, const NonlinearComponent &other)
 
const CuVector< double > & ValueSum () const
 
const CuVector< double > & DerivSum () const
 
double Count () const
 
void SetDim (int32 dim)
 
Public Member Functions inherited from Component
 Component ()

virtual std::string Type () const =0

virtual int32 Index () const
 Returns the index in the sequence of layers in the neural net; intended only to be used in debugging information.

virtual void SetIndex (int32 index)

virtual std::vector< int32 > Context () const
 Return a vector describing the temporal context this component requires for each frame of output, as a sorted list.

virtual void Propagate (const ChunkInfo &in_info, const ChunkInfo &out_info, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) const =0
 Perform forward pass propagation Input->Output.

void Propagate (const ChunkInfo &in_info, const ChunkInfo &out_info, const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out) const
 A non-virtual propagate function that first resizes output if necessary.

virtual void Backprop (const ChunkInfo &in_info, const ChunkInfo &out_info, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, const CuMatrixBase< BaseFloat > &out_deriv, Component *to_update, CuMatrix< BaseFloat > *in_deriv) const =0
 Perform backward pass propagation of the derivative, and also either update the model (if to_update == this) or update another model or compute the model derivative (otherwise).

virtual bool BackpropNeedsInput () const

virtual bool BackpropNeedsOutput () const

virtual Component * Copy () const =0
 Copy component (deep copy).
 
virtual std::string Info () const
 
virtual ~Component ()
 

Protected Member Functions

void UpdateStats (const CuMatrixBase< BaseFloat > &out_value, const CuMatrixBase< BaseFloat > *deriv=NULL)
 
const NonlinearComponent & operator= (const NonlinearComponent &other)
 

Protected Attributes

int32 dim_
 
CuVector< double > value_sum_
 
CuVector< double > deriv_sum_
 
double count_
 
std::mutex mutex_
 

Friends

class NormalizationComponent
 
class SigmoidComponent
 
class TanhComponent
 
class SoftmaxComponent
 
class LogSoftmaxComponent
 
class RectifiedLinearComponent
 
class SoftHingeComponent
 

Additional Inherited Members

Static Public Member Functions inherited from Component
static Component * ReadNew (std::istream &is, bool binary)
 Read component from stream.

static Component * NewFromString (const std::string &initializer_line)
 Initialize the Component from one line that begins with the type.

static Component * NewComponentOfType (const std::string &type)
 Return a new Component of the given type.
 

Detailed Description

This kind of Component is a base-class for things like sigmoid and softmax.

Definition at line 352 of file nnet-component.h.

Constructor & Destructor Documentation

NonlinearComponent ( int32  dim)
inline explicit

Definition at line 355 of file nnet-component.h.

References NonlinearComponent::Init().

explicit NonlinearComponent(int32 dim) { Init(dim); }
NonlinearComponent ( )
inline

Definition at line 356 of file nnet-component.h.

NonlinearComponent(): dim_(0) { } // e.g. prior to Read().
NonlinearComponent ( const NonlinearComponent &  other)
explicit

Definition at line 405 of file nnet-component.cc.

NonlinearComponent::NonlinearComponent(const NonlinearComponent &other):
    dim_(other.dim_), value_sum_(other.value_sum_), deriv_sum_(other.deriv_sum_),
    count_(other.count_) { }

Member Function Documentation

void Add ( BaseFloat  alpha,
const NonlinearComponent &  other 
)

Definition at line 362 of file nnet-component.cc.

References CuVectorBase< Real >::AddVec(), NonlinearComponent::count_, NonlinearComponent::deriv_sum_, CuVectorBase< Real >::Dim(), CuVector< Real >::Resize(), and NonlinearComponent::value_sum_.

Referenced by Nnet::AddNnet(), Nnet::CopyStatsFrom(), and main().

void NonlinearComponent::Add(BaseFloat alpha, const NonlinearComponent &other) {
  // Lazily size our accumulators if only the other component has statistics.
  if (value_sum_.Dim() == 0 && other.value_sum_.Dim() != 0)
    value_sum_.Resize(other.value_sum_.Dim());
  if (deriv_sum_.Dim() == 0 && other.deriv_sum_.Dim() != 0)
    deriv_sum_.Resize(other.deriv_sum_.Dim());
  if (other.value_sum_.Dim() != 0)
    value_sum_.AddVec(alpha, other.value_sum_);
  if (other.deriv_sum_.Dim() != 0)
    deriv_sum_.AddVec(alpha, other.deriv_sum_);
  count_ += alpha * other.count_;
}
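
The accumulation performed by Add() can be illustrated with a minimal, self-contained sketch. It is not Kaldi code: `ToyStats` is a hypothetical stand-in, and `std::vector<double>` replaces `CuVector<double>` so that the example compiles without Kaldi.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the statistics held by NonlinearComponent.
// std::vector<double> replaces CuVector<double> for illustration only.
struct ToyStats {
  std::vector<double> value_sum;
  std::vector<double> deriv_sum;
  double count = 0.0;

  // Same shape as NonlinearComponent::Add(): size empty accumulators to
  // match `other`, then add alpha * other to each statistic.
  void Add(double alpha, const ToyStats &other) {
    if (value_sum.empty() && !other.value_sum.empty())
      value_sum.resize(other.value_sum.size(), 0.0);
    if (deriv_sum.empty() && !other.deriv_sum.empty())
      deriv_sum.resize(other.deriv_sum.size(), 0.0);
    if (!other.value_sum.empty()) {
      // Non-empty accumulators must agree in dimension.
      assert(value_sum.size() == other.value_sum.size());
      for (std::size_t i = 0; i < value_sum.size(); ++i)
        value_sum[i] += alpha * other.value_sum[i];
    }
    if (!other.deriv_sum.empty()) {
      assert(deriv_sum.size() == other.deriv_sum.size());
      for (std::size_t i = 0; i < deriv_sum.size(); ++i)
        deriv_sum[i] += alpha * other.deriv_sum[i];
    }
    count += alpha * other.count;
  }
};
```

Calling `a.Add(0.5, b)` on an empty `a` leaves `a` holding half of each of `b`'s sums and half of its count. Because the sums and the count are scaled by the same `alpha`, any per-frame average computed from them is unchanged by the scaling.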
double Count ( ) const
inline

Definition at line 379 of file nnet-component.h.

References NonlinearComponent::count_.

Referenced by NnetStats::AddStatsFromNnet(), and kaldi::nnet2::FixNnet().

double Count() const { return count_; }
const CuVector<double>& DerivSum ( ) const
inline

Definition at line 378 of file nnet-component.h.

References NonlinearComponent::deriv_sum_.

Referenced by NnetStats::AddStatsFromNnet(), and kaldi::nnet2::FixNnet().

const CuVector<double> &DerivSum() const { return deriv_sum_; }
void InitFromString ( std::string  args)
virtual

We implement InitFromString at this level.

Implements Component.

Reimplemented in PowerComponent.

Definition at line 409 of file nnet-component.cc.

References NonlinearComponent::Init(), KALDI_ERR, kaldi::nnet2::ParseFromString(), and Component::Type().

Referenced by kaldi::nnet2::UnitTestSigmoidComponent().

void NonlinearComponent::InitFromString(std::string args) {
  std::string orig_args(args);
  int32 dim;
  bool ok = ParseFromString("dim", &args, &dim);
  // Reject a missing "dim", leftover unparsed arguments, or a non-positive dim.
  if (!ok || !args.empty() || dim <= 0)
    KALDI_ERR << "Invalid initializer for layer of type "
              << Type() << ": \"" << orig_args << "\"";
  Init(dim);
}
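
The three validation checks in InitFromString() can be sketched without Kaldi as follows. This assumes the initializer uses a `dim=<N>` key-value token. `ToyParseInt` and `ToyDimFromString` are hypothetical helpers written for this illustration, not Kaldi's ParseFromString().

```cpp
#include <cassert>
#include <sstream>
#include <stdexcept>
#include <string>

// Hypothetical helper, not Kaldi's ParseFromString(): looks for a
// "name=value" token in *args, stores the integer value in *param, and
// removes that token from *args. Returns false if no token parses cleanly.
bool ToyParseInt(const std::string &name, std::string *args, int *param) {
  assert(args != nullptr && param != nullptr);
  const std::string prefix = name + "=";
  std::istringstream in(*args);
  std::string token, rest;
  bool found = false;
  while (in >> token) {
    if (!found && token.compare(0, prefix.size(), prefix) == 0) {
      std::istringstream value(token.substr(prefix.size()));
      int parsed = 0;
      if (value >> parsed && value.eof()) {
        *param = parsed;
        found = true;
        continue;  // Consume the token: it is not copied into `rest`.
      }
    }
    if (!rest.empty()) rest += ' ';
    rest += token;
  }
  *args = rest;
  return found;
}

// Same three checks as NonlinearComponent::InitFromString(): the key must
// be present, nothing else may remain, and the dimension must be positive.
int ToyDimFromString(std::string args) {
  const std::string orig_args(args);
  int dim = 0;
  const bool ok = ToyParseInt("dim", &args, &dim);
  if (!ok || !args.empty() || dim <= 0)
    throw std::runtime_error("Invalid initializer: \"" + orig_args + "\"");
  return dim;
}
```

With these helpers, `ToyDimFromString("dim=10")` returns 10, while `"dim=10 extra=1"` (leftover argument) and `"dim=0"` (non-positive dimension) both throw.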
virtual int32 InputDim ( ) const
inline virtual

Get size of input vectors.

Implements Component.

Reimplemented in PowerComponent.

Definition at line 359 of file nnet-component.h.

References NonlinearComponent::dim_.

Referenced by kaldi::nnet2::FixNnet(), and NonlinearComponent::UpdateStats().

virtual int32 InputDim() const { return dim_; }
const NonlinearComponent& operator= ( const NonlinearComponent &  other)
protected
virtual int32 OutputDim ( ) const
inline virtual

Get size of output vectors.

Implements Component.

Reimplemented in PowerComponent.

Definition at line 360 of file nnet-component.h.

References NonlinearComponent::dim_.

virtual int32 OutputDim() const { return dim_; }
void Read ( std::istream &  is,
bool  binary 
)
virtual

We implement Read at this level as it just needs the Type().

Implements Component.

Reimplemented in PowerComponent.

Definition at line 374 of file nnet-component.cc.

References NonlinearComponent::count_, NonlinearComponent::deriv_sum_, NonlinearComponent::dim_, kaldi::nnet2::ExpectOneOrTwoTokens(), kaldi::ExpectToken(), CuVector< Real >::Read(), kaldi::ReadBasicType(), Component::Type(), and NonlinearComponent::value_sum_.

void NonlinearComponent::Read(std::istream &is, bool binary) {
  std::ostringstream ostr_beg, ostr_end;
  ostr_beg << "<" << Type() << ">"; // e.g. "<SigmoidComponent>"
  ostr_end << "</" << Type() << ">"; // e.g. "</SigmoidComponent>"
  ExpectOneOrTwoTokens(is, binary, ostr_beg.str(), "<Dim>");
  ReadBasicType(is, binary, &dim_); // Read dimension.
  ExpectToken(is, binary, "<ValueSum>");
  value_sum_.Read(is, binary);
  ExpectToken(is, binary, "<DerivSum>");
  deriv_sum_.Read(is, binary);
  ExpectToken(is, binary, "<Count>");
  ReadBasicType(is, binary, &count_);
  ExpectToken(is, binary, ostr_end.str());
}
void SetDim ( int32  dim)

Definition at line 321 of file nnet-component.cc.

References NonlinearComponent::count_, NonlinearComponent::deriv_sum_, NonlinearComponent::dim_, KALDI_ASSERT, CuVector< Real >::Resize(), and NonlinearComponent::value_sum_.

void NonlinearComponent::SetDim(int32 dim) {
  KALDI_ASSERT(dim > 0);
  dim_ = dim;
  // Resizing zeroes the accumulators (default resize type is kSetZero).
  value_sum_.Resize(dim);
  deriv_sum_.Resize(dim);
  count_ = 0.0;
}
void UpdateStats ( const CuMatrixBase< BaseFloat > &  out_value,
const CuMatrixBase< BaseFloat > *  deriv = NULL 
)
protected

Definition at line 329 of file nnet-component.cc.

References CuVectorBase< Real >::AddRowSumMat(), CuVectorBase< Real >::AddVec(), NonlinearComponent::count_, NonlinearComponent::deriv_sum_, CuVectorBase< Real >::Dim(), NonlinearComponent::InputDim(), KALDI_ASSERT, NonlinearComponent::mutex_, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), CuVector< Real >::Resize(), CuVectorBase< Real >::SetZero(), and NonlinearComponent::value_sum_.

Referenced by SigmoidComponent::Backprop(), TanhComponent::Backprop(), RectifiedLinearComponent::Backprop(), SoftHingeComponent::Backprop(), SoftmaxComponent::Backprop(), and LogSoftmaxComponent::Backprop().

void NonlinearComponent::UpdateStats(const CuMatrixBase<BaseFloat> &out_value,
                                     const CuMatrixBase<BaseFloat> *deriv) {
  KALDI_ASSERT(out_value.NumCols() == InputDim());
  // Check we have the correct dimensions.
  if (value_sum_.Dim() != InputDim() ||
      (deriv != NULL && deriv_sum_.Dim() != InputDim())) {
    std::lock_guard<std::mutex> lock(mutex_);
    if (value_sum_.Dim() != InputDim()) {
      value_sum_.Resize(InputDim());
      count_ = 0.0;
    }
    if (deriv != NULL && deriv_sum_.Dim() != InputDim()) {
      deriv_sum_.Resize(InputDim());
      count_ = 0.0;
      value_sum_.SetZero();
    }
  }
  count_ += out_value.NumRows();
  CuVector<BaseFloat> temp(InputDim());
  temp.AddRowSumMat(1.0, out_value, 0.0);
  value_sum_.AddVec(1.0, temp);
  if (deriv != NULL) {
    temp.AddRowSumMat(1.0, *deriv, 0.0);
    deriv_sum_.AddVec(1.0, temp);
  }
}
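
The core of UpdateStats() is a per-column sum over the rows of `out_value` plus a running row count. A minimal sketch of that accumulation follows, under stated assumptions: `ToyMatrix` and `ToyAccumulator` are hypothetical names, the derivative statistics and the mutex are omitted, and `MeanValue()` is one plausible way a caller might combine `ValueSum()` and `Count()`, not something this class provides.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in: each inner vector is one row (frame) of out_value.
using ToyMatrix = std::vector<std::vector<float>>;

struct ToyAccumulator {
  std::size_t dim;
  std::vector<double> value_sum;
  double count = 0.0;

  explicit ToyAccumulator(std::size_t d) : dim(d) {}

  // Mirrors the core of UpdateStats(): lazily size the accumulator, add the
  // number of rows to the count, and add each column's sum to value_sum.
  void Update(const ToyMatrix &out_value) {
    if (value_sum.size() != dim) {
      value_sum.assign(dim, 0.0);
      count = 0.0;
    }
    count += static_cast<double>(out_value.size());
    for (const std::vector<float> &row : out_value) {
      assert(row.size() == dim);
      for (std::size_t j = 0; j < dim; ++j) value_sum[j] += row[j];
    }
  }

  // Average output of unit j over all frames accumulated so far.
  double MeanValue(std::size_t j) const {
    assert(j < dim && count > 0.0);
    return value_sum[j] / count;
  }
};
```

Accumulating sums in `double` while the inputs are `float` matches the real class, where the matrices hold `BaseFloat` and the sums are `CuVector<double>`.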
const CuVector<double>& ValueSum ( ) const
inline

Definition at line 377 of file nnet-component.h.

References NonlinearComponent::value_sum_.

Referenced by NnetStats::AddStatsFromNnet().

const CuVector<double> &ValueSum() const { return value_sum_; }
void Write ( std::ostream &  os,
bool  binary 
) const
virtual

Write component to stream.

Implements Component.

Reimplemented in PowerComponent.

Definition at line 389 of file nnet-component.cc.

References NonlinearComponent::count_, NonlinearComponent::deriv_sum_, NonlinearComponent::dim_, Component::Type(), NonlinearComponent::value_sum_, CuVector< Real >::Write(), kaldi::WriteBasicType(), and kaldi::WriteToken().

void NonlinearComponent::Write(std::ostream &os, bool binary) const {
  std::ostringstream ostr_beg, ostr_end;
  ostr_beg << "<" << Type() << ">"; // e.g. "<SigmoidComponent>"
  ostr_end << "</" << Type() << ">"; // e.g. "</SigmoidComponent>"
  WriteToken(os, binary, ostr_beg.str());
  WriteToken(os, binary, "<Dim>");
  WriteBasicType(os, binary, dim_);
  WriteToken(os, binary, "<ValueSum>");
  value_sum_.Write(os, binary);
  WriteToken(os, binary, "<DerivSum>");
  deriv_sum_.Write(os, binary);
  WriteToken(os, binary, "<Count>");
  WriteBasicType(os, binary, count_);
  WriteToken(os, binary, ostr_end.str());
}
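
Read() and Write() agree on a fixed token order: the opening type tag, `<Dim>`, `<ValueSum>`, `<DerivSum>`, `<Count>`, then the closing type tag. The following sketch round-trips that layout in a simplified text form. Everything prefixed `Toy` is hypothetical, and the whitespace and the bracketed vector syntax are assumptions made for this illustration rather than a specification of Kaldi's on-disk format.

```cpp
#include <cassert>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for the fields NonlinearComponent serializes.
struct ToyComponent {
  std::string type;  // e.g. "SigmoidComponent"; plays the role of Type()
  int dim = 0;
  std::vector<double> value_sum, deriv_sum;
  double count = 0.0;
};

// Assumed text form for a vector: "[ v0 v1 ... ] ".
void ToyWriteVec(std::ostream &os, const std::vector<double> &v) {
  os << "[ ";
  for (double x : v) os << x << ' ';
  os << "] ";
}

// Reads one whitespace-delimited token and throws if it is not `token`.
void ToyExpect(std::istream &is, const std::string &token) {
  std::string got;
  if (!(is >> got) || got != token)
    throw std::runtime_error("expected " + token + ", got \"" + got + "\"");
}

std::vector<double> ToyReadVec(std::istream &is) {
  ToyExpect(is, "[");
  std::vector<double> v;
  std::string tok;
  while (is >> tok && tok != "]") v.push_back(std::stod(tok));
  if (tok != "]") throw std::runtime_error("unterminated vector");
  return v;
}

// Same token order as Write().
void ToyWrite(std::ostream &os, const ToyComponent &c) {
  os << '<' << c.type << "> <Dim> " << c.dim << " <ValueSum> ";
  ToyWriteVec(os, c.value_sum);
  os << "<DerivSum> ";
  ToyWriteVec(os, c.deriv_sum);
  os << "<Count> " << c.count << " </" << c.type << "> ";
}

// Same token order as Read(). The real Read() calls ExpectOneOrTwoTokens()
// for the first two tokens; this sketch always requires the opening tag.
void ToyRead(std::istream &is, ToyComponent *c) {
  assert(c != nullptr && !c->type.empty());
  ToyExpect(is, "<" + c->type + ">");
  ToyExpect(is, "<Dim>");
  if (!(is >> c->dim)) throw std::runtime_error("bad <Dim>");
  ToyExpect(is, "<ValueSum>");
  c->value_sum = ToyReadVec(is);
  ToyExpect(is, "<DerivSum>");
  c->deriv_sum = ToyReadVec(is);
  ToyExpect(is, "<Count>");
  if (!(is >> c->count)) throw std::runtime_error("bad <Count>");
  ToyExpect(is, "</" + c->type + ">");
}
```

Because the tags are derived from the type name, the same reader and writer serve every subclass. This is why the class can implement Read() at this level using only Type().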

Friends And Related Function Documentation

friend class LogSoftmaxComponent
friend

Definition at line 389 of file nnet-component.h.

friend class NormalizationComponent
friend

Definition at line 385 of file nnet-component.h.

friend class RectifiedLinearComponent
friend

Definition at line 390 of file nnet-component.h.

friend class SigmoidComponent
friend

Definition at line 386 of file nnet-component.h.

friend class SoftHingeComponent
friend

Definition at line 391 of file nnet-component.h.

friend class SoftmaxComponent
friend

Definition at line 388 of file nnet-component.h.

friend class TanhComponent
friend

Definition at line 387 of file nnet-component.h.

Member Data Documentation

std::mutex mutex_
protected

Definition at line 408 of file nnet-component.h.

Referenced by NonlinearComponent::UpdateStats().


The documentation for this class was generated from the following files: