SpliceMaxComponent Class Reference

Like SpliceComponent, but outputs the maximum over the spliced input frames (the max is taken across time).

#include <nnet-component.h>


Public Member Functions

 SpliceMaxComponent ()

void Init (int32 dim, std::vector< int32 > context)

virtual std::string Type () const

virtual std::string Info () const

virtual void InitFromString (std::string args)
 Initialize, typically from a line of a config file.

virtual int32 InputDim () const
 Get size of input vectors.

virtual int32 OutputDim () const
 Get size of output vectors.

virtual std::vector< int32 > Context () const
 Return a vector describing the temporal context this component requires for each frame of output, as a sorted list.

virtual void Propagate (const ChunkInfo &in_info, const ChunkInfo &out_info, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) const
 Perform forward pass propagation Input->Output.

virtual void Backprop (const ChunkInfo &in_info, const ChunkInfo &out_info, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, const CuMatrixBase< BaseFloat > &out_deriv, Component *to_update, CuMatrix< BaseFloat > *in_deriv) const
 Perform backward pass propagation of the derivative, and also either update the model (if to_update == this) or update another model or compute the model derivative (otherwise).

virtual bool BackpropNeedsInput () const

virtual bool BackpropNeedsOutput () const

virtual Component * Copy () const
 Copy component (deep copy).

virtual void Read (std::istream &is, bool binary)

virtual void Write (std::ostream &os, bool binary) const
 Write component to stream.

- Public Member Functions inherited from Component
 Component ()

virtual int32 Index () const
 Returns the index in the sequence of layers in the neural net; intended only to be used in debugging information.

virtual void SetIndex (int32 index)

void Propagate (const ChunkInfo &in_info, const ChunkInfo &out_info, const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out) const
 A non-virtual propagate function that first resizes output if necessary.

virtual ~Component ()
 

Private Member Functions

 KALDI_DISALLOW_COPY_AND_ASSIGN (SpliceMaxComponent)
 

Private Attributes

int32 dim_
 
std::vector< int32 > context_
 

Additional Inherited Members

- Static Public Member Functions inherited from Component
static Component * ReadNew (std::istream &is, bool binary)
 Read component from stream.

static Component * NewFromString (const std::string &initializer_line)
 Initialize the Component from one line that will contain first the type.

static Component * NewComponentOfType (const std::string &type)
 Return a new Component of the given type.
 

Detailed Description

Like SpliceComponent, but outputs the maximum over the spliced input frames (the max is taken across time).
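As a hedged restatement of the sentence above, the operation can be written as a formula. The symbols are notation introduced here, not names from the source: C is the component's context vector of frame offsets, x_t is the input frame at time t, and y_t is the corresponding output frame.

```latex
y_t[d] = \max_{k \in C} x_{t+k}[d], \qquad d = 0, \dots, \mathrm{dim} - 1
```

For example, with C = (-1, 0, 1) this gives y_t[d] = max(x_{t-1}[d], x_t[d], x_{t+1}[d]). The output keeps the input dimension, which agrees with InputDim() and OutputDim() both returning dim_.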

Definition at line 1133 of file nnet-component.h.

Constructor & Destructor Documentation

◆ SpliceMaxComponent()

SpliceMaxComponent ( )
inline

Definition at line 1135 of file nnet-component.h.

{ } // called only prior to Read() or Init().

Member Function Documentation

◆ Backprop()

void Backprop ( const ChunkInfo &  in_info,
const ChunkInfo &  out_info,
const CuMatrixBase< BaseFloat > &  in_value,
const CuMatrixBase< BaseFloat > &  out_value,
const CuMatrixBase< BaseFloat > &  out_deriv,
Component *  to_update,
CuMatrix< BaseFloat > *  in_deriv
) const
virtual

Perform backward pass propagation of the derivative, and also either update the model (if to_update == this) or update another model or compute the model derivative (otherwise).

Note: in_value and out_value are the values of the input and output of the component, and these may be dummy variables if respectively BackpropNeedsInput() or BackpropNeedsOutput() return false for that component (not all components need these).

num_chunks lets us treat the input matrix as contiguous-in-time chunks of equal size; it only matters if splicing is involved.

Implements Component.

Definition at line 2891 of file nnet-component.cc.

References ChunkInfo::Check(), ChunkInfo::CheckSize(), ChunkInfo::ChunkSize(), ChunkInfo::GetIndex(), ChunkInfo::GetOffset(), AffineComponent::InputDim(), KALDI_ASSERT, ChunkInfo::NumChunks(), ChunkInfo::NumCols(), CuMatrixBase< Real >::NumCols(), ChunkInfo::NumRows(), CuMatrixBase< Real >::NumRows(), and CuMatrix< Real >::Resize().

Referenced by kaldi::nnet2::BasicDebugTestForSpliceMax().

void SpliceMaxComponent::Backprop(const ChunkInfo &in_info,
                                  const ChunkInfo &out_info,
                                  const CuMatrixBase<BaseFloat> &in_value,
                                  const CuMatrixBase<BaseFloat> &out_value,
                                  const CuMatrixBase<BaseFloat> &out_deriv,
                                  Component *to_update,
                                  CuMatrix<BaseFloat> *in_deriv) const {
  in_info.Check();
  out_info.Check();
  in_info.CheckSize(in_value);
  out_info.CheckSize(out_deriv);
  in_deriv->Resize(in_info.NumRows(), in_info.NumCols());
  KALDI_ASSERT(in_info.NumChunks() == out_info.NumChunks());

  int32 out_chunk_size = out_info.ChunkSize(),
        in_chunk_size = in_info.ChunkSize(),
        dim = out_deriv.NumCols();

  KALDI_ASSERT(dim == InputDim());

  for (int32 chunk = 0; chunk < in_info.NumChunks(); chunk++) {
    CuSubMatrix<BaseFloat> in_deriv_chunk(*in_deriv,
                                          chunk * in_chunk_size,
                                          in_chunk_size,
                                          0, dim),
                           in_value_chunk(in_value,
                                          chunk * in_chunk_size,
                                          in_chunk_size,
                                          0, dim),
                           out_deriv_chunk(out_deriv,
                                           chunk * out_chunk_size,
                                           out_chunk_size,
                                           0, dim);
    for (int32 r = 0; r < out_deriv_chunk.NumRows(); r++) {
      int32 out_chunk_ind = r;
      int32 out_chunk_offset =
          out_info.GetOffset(out_chunk_ind);

      for (int32 c = 0; c < dim; c++) {
        int32 in_r_max = -1;
        BaseFloat max_input = -std::numeric_limits<BaseFloat>::infinity();
        for (int32 context_ind = 0;
             context_ind < context_.size(); context_ind++) {
          int32 in_r =
              in_info.GetIndex(out_chunk_offset + context_[context_ind]);
          BaseFloat input = in_value_chunk(in_r, c);
          if (input > max_input) {
            max_input = input;
            in_r_max = in_r;
          }
        }
        KALDI_ASSERT(in_r_max != -1);
        (*in_deriv)(in_r_max, c) += out_deriv_chunk(r, c);
      }
    }
  }
}
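
The derivative routing in the listing above can be illustrated without any Kaldi types. The following is a minimal sketch for a single chunk, not the library's implementation: `Matrix`, `SpliceMaxBackprop`, and the row mapping (output row r corresponds to input row r + left_context, with left_context = -context.front()) are assumptions made for illustration, whereas Kaldi obtains the rows from ChunkInfo::GetOffset() and ChunkInfo::GetIndex().

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

using Matrix = std::vector<std::vector<float>>;  // rows are frames

// Hypothetical helper, not part of Kaldi: for every output element, finds the
// input row in the context window that attained the maximum and adds the
// output derivative to that row only. All other rows receive zero from it.
Matrix SpliceMaxBackprop(const Matrix &in_value, const Matrix &out_deriv,
                         const std::vector<int> &context) {
  assert(!context.empty() && context.front() <= 0 && context.back() >= 0);
  assert(!in_value.empty());
  const int left_context = -context.front();  // assumed row mapping
  const int num_in_rows = static_cast<int>(in_value.size());
  const std::size_t dim = in_value[0].size();
  Matrix in_deriv(in_value.size(), std::vector<float>(dim, 0.0f));
  for (std::size_t r = 0; r < out_deriv.size(); ++r) {
    for (std::size_t c = 0; c < dim; ++c) {
      int in_r_max = -1;
      float max_input = -std::numeric_limits<float>::infinity();
      for (int offset : context) {
        const int in_r = static_cast<int>(r) + left_context + offset;
        assert(in_r >= 0 && in_r < num_in_rows);
        if (in_value[in_r][c] > max_input) {
          max_input = in_value[in_r][c];
          in_r_max = in_r;
        }
      }
      assert(in_r_max != -1);
      in_deriv[in_r_max][c] += out_deriv[r][c];
    }
  }
  return in_deriv;
}
```

With a one-column input (1, 5, 2, 0), context (-1, 0, 1), and output derivatives (10, 20), both output frames take their maximum from input row 1, so that row accumulates both derivatives and every other row stays at zero.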

◆ BackpropNeedsInput()

virtual bool BackpropNeedsInput ( ) const
inline virtual

Reimplemented from Component.

Definition at line 1156 of file nnet-component.h.

{ return true; }

◆ BackpropNeedsOutput()

virtual bool BackpropNeedsOutput ( ) const
inline virtual

Reimplemented from Component.

Definition at line 1157 of file nnet-component.h.

References kaldi::cu::Copy(), and KALDI_DISALLOW_COPY_AND_ASSIGN.

{ return false; }

◆ Context()

virtual std::vector<int32> Context ( ) const
inline virtual

Return a vector describing the temporal context this component requires for each frame of output, as a sorted list.

The default implementation returns the vector (0). A splicing layer might return, e.g., (-2, -1, 0, 1, 2), but the offsets do not have to be contiguous. Note: the context needed by the entire network is a function of the contexts needed by all the components. It is required that Context().front() <= 0 and Context().back() >= 0.

Reimplemented from Component.

Definition at line 1143 of file nnet-component.h.

References Component::Propagate().

{ return context_; }
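
The requirements stated above (a sorted list with front() <= 0 and back() >= 0) can be checked with a few lines of standard C++. This is a hedged sketch: `IsValidContext` and `ContextSpan` are hypothetical helpers, not Kaldi functions, and the interpretation of the span as the number of extra input frames assumes a single contiguous chunk.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper, not part of Kaldi: checks the documented requirements
// on a context vector (non-empty, sorted, front() <= 0, back() >= 0).
bool IsValidContext(const std::vector<int> &context) {
  return !context.empty() &&
         std::is_sorted(context.begin(), context.end()) &&
         context.front() <= 0 && context.back() >= 0;
}

// Distance from the earliest to the latest offset. Under the assumption of
// one contiguous chunk, the input needs this many more frames than the output.
int ContextSpan(const std::vector<int> &context) {
  assert(IsValidContext(context));
  return context.back() - context.front();
}
```

For instance, (-2, -1, 0, 1, 2) and (0) pass the check, while (1, 2) fails because its first offset is positive; the non-contiguous context (-2, 0, 3) has a span of 5.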

◆ Copy()

Component * Copy ( ) const
virtual

Copy component (deep copy).

Implements Component.

Definition at line 2949 of file nnet-component.cc.

References SpliceMaxComponent::Init().

Component *SpliceMaxComponent::Copy() const {
  SpliceMaxComponent *ans = new SpliceMaxComponent();
  ans->Init(dim_, context_);
  return ans;
}

◆ Info()

std::string Info ( ) const
virtual

Reimplemented from Component.

Definition at line 2805 of file nnet-component.cc.

References Component::Info().

std::string SpliceMaxComponent::Info() const {
  std::stringstream stream;
  std::ostringstream os;
  std::copy(context_.begin(), context_.end(),
            std::ostream_iterator<int32>(os, " "));
  stream << Component::Info() << ", context=" << os.str();
  return stream.str();
}

◆ Init()

void Init ( int32  dim,
std::vector< int32 >  context
)

Definition at line 2814 of file nnet-component.cc.

References KALDI_ASSERT.

Referenced by kaldi::nnet2::BasicDebugTestForSpliceMax(), and SpliceMaxComponent::Copy().

void SpliceMaxComponent::Init(int32 dim,
                              std::vector<int32> context) {
  dim_ = dim;
  context_ = context;
  KALDI_ASSERT(dim_ > 0 && context_.front() <= 0 && context_.back() >= 0);
}

◆ InitFromString()

void InitFromString ( std::string  args)
virtual

Initialize, typically from a line of a config file.

The "args" will contain any parameters that need to be passed to the Component, e.g. dimensions.

Implements Component.

Definition at line 2823 of file nnet-component.cc.

References rnnlm::i, AffineComponentPreconditionedOnline::Init(), KALDI_ASSERT, KALDI_ERR, kaldi::nnet2::ParseFromString(), and AffineComponentPreconditionedOnline::Type().

void SpliceMaxComponent::InitFromString(std::string args) {
  std::string orig_args(args);
  int32 dim, left_context, right_context;
  std::vector<int32> context;
  bool dim_ok = ParseFromString("dim", &args, &dim);
  bool context_ok = ParseFromString("context", &args, &context);
  bool left_right_context_ok = ParseFromString("left-context",
                                               &args, &left_context) &&
                               ParseFromString("right-context", &args,
                                               &right_context);

  if (!(dim_ok && (context_ok || left_right_context_ok)) ||
      !args.empty() || dim <= 0)
    KALDI_ERR << "Invalid initializer for layer of type "
              << Type() << ": \"" << orig_args << "\"";
  if (left_right_context_ok) {
    KALDI_ASSERT(context.size() == 0);
    for (int32 i = -1 * left_context; i <= right_context; i++)
      context.push_back(i);
  }
  Init(dim, context);
}
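
The listing accepts either an explicit "context" list or a "left-context"/"right-context" pair, and expands the pair into a contiguous list of offsets. That expansion can be sketched in isolation as below; `ExpandContext` is a hypothetical name used only for illustration. A config line using the pair form would presumably look like `dim=40 left-context=2 right-context=2`, where the key names come from the listing but the `key=value` syntax and the value 40 are assumptions.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper mirroring the loop in InitFromString(): expands a
// left/right context pair into the explicit list of frame offsets
// -left_context, ..., 0, ..., right_context.
std::vector<int> ExpandContext(int left_context, int right_context) {
  std::vector<int> context;
  for (int i = -left_context; i <= right_context; ++i)
    context.push_back(i);
  return context;
}
```

Calling `ExpandContext(2, 1)` yields the offsets -2, -1, 0, 1, which is the vector that would then be passed on to Init().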

◆ InputDim()

virtual int32 InputDim ( ) const
inline virtual

Get size of input vectors.

Implements Component.

Definition at line 1141 of file nnet-component.h.

{ return dim_; }

◆ KALDI_DISALLOW_COPY_AND_ASSIGN()

KALDI_DISALLOW_COPY_AND_ASSIGN ( SpliceMaxComponent  )
private

◆ OutputDim()

virtual int32 OutputDim ( ) const
inline virtual

Get size of output vectors.

Implements Component.

Definition at line 1142 of file nnet-component.h.

Referenced by kaldi::nnet2::BasicDebugTestForSpliceMax().

{ return dim_; }

◆ Propagate()

void Propagate ( const ChunkInfo &  in_info,
const ChunkInfo &  out_info,
const CuMatrixBase< BaseFloat > &  in,
CuMatrixBase< BaseFloat > *  out 
) const
virtual

Perform forward pass propagation Input->Output.

Each row of the input is one frame or training example. The input is interpreted as "num_chunks" equally sized chunks of frames; this only matters for layers that do things like context splicing. Typically the number of chunks will either be 1 (when we're processing a single contiguous chunk of data) or will be the same as the number of rows of the input, but other values are possible if some layers do splicing.

Implements Component.

Definition at line 2847 of file nnet-component.cc.

References ChunkInfo::Check(), ChunkInfo::CheckSize(), ChunkInfo::ChunkSize(), CuMatrixBase< Real >::CopyFromMat(), CuMatrixBase< Real >::CopyRows(), ChunkInfo::GetIndex(), ChunkInfo::GetOffset(), rnnlm::i, KALDI_ASSERT, CuMatrixBase< Real >::Max(), ChunkInfo::NumChunks(), and ChunkInfo::NumCols().

Referenced by kaldi::nnet2::BasicDebugTestForSpliceMax().

void SpliceMaxComponent::Propagate(const ChunkInfo &in_info,
                                   const ChunkInfo &out_info,
                                   const CuMatrixBase<BaseFloat> &in,
                                   CuMatrixBase<BaseFloat> *out) const {
  in_info.Check();
  out_info.Check();
  in_info.CheckSize(in);
  out_info.CheckSize(*out);
  KALDI_ASSERT(in_info.NumChunks() == out_info.NumChunks());
  int32 in_chunk_size = in_info.ChunkSize(),
        out_chunk_size = out_info.ChunkSize(),
        dim = in_info.NumCols();

  CuMatrix<BaseFloat> input_chunk_part(out_chunk_size, dim);
  for (int32 chunk = 0; chunk < in_info.NumChunks(); chunk++) {
    CuSubMatrix<BaseFloat> input_chunk(in,
                                       chunk * in_chunk_size, in_chunk_size,
                                       0, dim),
                           output_chunk(*out,
                                        chunk * out_chunk_size,
                                        out_chunk_size, 0, dim);
    for (int32 offset = 0; offset < context_.size(); offset++) {
      // computing the indices to copy into input_chunk_part from input_chunk
      // copy the rows of the input matrix which correspond to the current
      // context index
      std::vector<int32> input_chunk_inds(out_chunk_size);
      for (int32 i = 0; i < out_chunk_size; i++) {
        int32 out_chunk_ind = i;
        int32 out_chunk_offset =
            out_info.GetOffset(out_chunk_ind);
        input_chunk_inds[i] =
            in_info.GetIndex(out_chunk_offset + context_[offset]);
      }
      CuArray<int32> cu_chunk_inds(input_chunk_inds);
      input_chunk_part.CopyRows(input_chunk, cu_chunk_inds);
      if (offset == 0) {
        output_chunk.CopyFromMat(input_chunk_part);
      } else {
        output_chunk.Max(input_chunk_part);
      }
    }
  }
}
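
The forward pass above copies the rows for the first context offset and then takes an elementwise maximum with the rows for each later offset. A minimal CPU-only sketch of that pattern for a single chunk follows. `Matrix` and `SpliceMaxPropagate` are hypothetical names, and the row mapping (output row r reads input row r + left_context + offset) is an assumption standing in for the ChunkInfo lookups.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>;  // rows are frames

// Hypothetical helper, not part of Kaldi: computes the per-dimension maximum
// over the context window for one contiguous chunk of frames.
Matrix SpliceMaxPropagate(const Matrix &in, const std::vector<int> &context) {
  assert(!context.empty() && context.front() <= 0 && context.back() >= 0);
  const int left_context = -context.front();  // assumed row mapping
  const int span = context.back() - context.front();
  assert(static_cast<int>(in.size()) > span);
  const std::size_t out_rows = in.size() - static_cast<std::size_t>(span);
  Matrix out(out_rows);
  for (std::size_t k = 0; k < context.size(); ++k) {
    for (std::size_t r = 0; r < out_rows; ++r) {
      const int in_r = static_cast<int>(r) + left_context + context[k];
      const std::vector<float> &src = in[in_r];
      if (k == 0) {
        out[r] = src;  // first offset: plain copy, like CopyFromMat()
      } else {
        for (std::size_t c = 0; c < src.size(); ++c)
          out[r][c] = std::max(out[r][c], src[c]);  // like Max()
      }
    }
  }
  return out;
}
```

Processing one offset at a time, as the listing does, keeps the inner work as whole-matrix copy and max operations, which is presumably why the original is structured around CopyRows() and Max() rather than a per-element loop.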

◆ Read()

void Read ( std::istream &  is,
bool  binary 
)
virtual

Implements Component.

Definition at line 2955 of file nnet-component.cc.

References kaldi::nnet2::ExpectOneOrTwoTokens(), kaldi::ExpectToken(), rnnlm::i, KALDI_ERR, kaldi::ReadBasicType(), kaldi::ReadIntegerVector(), and kaldi::ReadToken().

void SpliceMaxComponent::Read(std::istream &is, bool binary) {
  ExpectOneOrTwoTokens(is, binary, "<SpliceMaxComponent>", "<Dim>");
  ReadBasicType(is, binary, &dim_);
  std::string token;
  ReadToken(is, false, &token);
  if (token == "<LeftContext>") {
    int32 left_context = 0, right_context = 0;
    std::vector<int32> context;
    ReadBasicType(is, binary, &left_context);
    ExpectToken(is, binary, "<RightContext>");
    ReadBasicType(is, binary, &right_context);
    for (int32 i = -1 * left_context; i <= right_context; i++)
      context.push_back(i);
    context_ = context;
  } else if (token == "<Context>") {
    ReadIntegerVector(is, binary, &context_);
  } else {
    KALDI_ERR << "Unknown token" << token << ", the model might be corrupted";
  }
  ExpectToken(is, binary, "</SpliceMaxComponent>");
}

◆ Type()

virtual std::string Type ( ) const
inline virtual

Implements Component.

Definition at line 1138 of file nnet-component.h.

{ return "SpliceMaxComponent"; }

◆ Write()

void Write ( std::ostream &  os,
bool  binary 
) const
virtual

Write component to stream.

Implements Component.

Definition at line 2977 of file nnet-component.cc.

References kaldi::WriteBasicType(), kaldi::WriteIntegerVector(), and kaldi::WriteToken().

void SpliceMaxComponent::Write(std::ostream &os, bool binary) const {
  WriteToken(os, binary, "<SpliceMaxComponent>");
  WriteToken(os, binary, "<Dim>");
  WriteBasicType(os, binary, dim_);
  WriteToken(os, binary, "<Context>");
  WriteIntegerVector(os, binary, context_);
  WriteToken(os, binary, "</SpliceMaxComponent>");
}
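
The token order written above, which the <Context> branch of Read() accepts, can be imitated with plain iostreams to show the round trip. This is only a sketch: `WriteSpliceMaxText` and `ReadSpliceMaxText` are hypothetical helpers, and the exact whitespace and the bracketed vector layout are assumptions rather than Kaldi's guaranteed text format.

```cpp
#include <cassert>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for text-mode Write(): emits the same token sequence
// (<SpliceMaxComponent> <Dim> dim <Context> vector </SpliceMaxComponent>).
std::string WriteSpliceMaxText(int dim, const std::vector<int> &context) {
  std::ostringstream os;
  os << "<SpliceMaxComponent> <Dim> " << dim << " <Context> [ ";
  for (int offset : context) os << offset << ' ';
  os << "] </SpliceMaxComponent>";
  return os.str();
}

// Parses the string produced above; returns false on any unexpected token.
// Note: std::stoi throws if a vector entry is not an integer.
bool ReadSpliceMaxText(const std::string &text, int *dim,
                       std::vector<int> *context) {
  std::istringstream is(text);
  std::string token;
  if (!(is >> token) || token != "<SpliceMaxComponent>") return false;
  if (!(is >> token) || token != "<Dim>") return false;
  if (!(is >> *dim)) return false;
  if (!(is >> token) || token != "<Context>") return false;
  if (!(is >> token) || token != "[") return false;
  context->clear();
  while (is >> token && token != "]") context->push_back(std::stoi(token));
  if (token != "]") return false;
  return (is >> token) && token == "</SpliceMaxComponent>";
}
```

Note that Write() always stores the explicit <Context> vector, while Read() additionally accepts the <LeftContext>/<RightContext> form shown in its listing.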

Member Data Documentation

◆ context_

std::vector<int32> context_
private

Definition at line 1164 of file nnet-component.h.

◆ dim_

int32 dim_
private

Definition at line 1163 of file nnet-component.h.


The documentation for this class was generated from the following files: