SpliceComponent Class Reference

Splices a context window of frames together [over time].

#include <nnet-component.h>

Public Member Functions

 SpliceComponent ()
 
void Init (int32 input_dim, std::vector< int32 > context, int32 const_component_dim=0)
 
virtual std::string Type () const
 
virtual std::string Info () const
 
virtual void InitFromString (std::string args)
 Initialize, typically from a line of a config file.
 
virtual int32 InputDim () const
 Get size of input vectors.
 
virtual int32 OutputDim () const
 Get size of output vectors.
 
virtual std::vector< int32 > Context () const
 Return a vector describing the temporal context this component requires for each frame of output, as a sorted list.
 
virtual void Propagate (const ChunkInfo &in_info, const ChunkInfo &out_info, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) const
 Perform forward pass propagation Input->Output.
 
virtual void Backprop (const ChunkInfo &in_info, const ChunkInfo &out_info, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, const CuMatrixBase< BaseFloat > &out_deriv, Component *to_update, CuMatrix< BaseFloat > *in_deriv) const
 Perform backward pass propagation of the derivative, and also either update the model (if to_update == this) or update another model or compute the model derivative (otherwise).
 
virtual bool BackpropNeedsInput () const
 
virtual bool BackpropNeedsOutput () const
 
virtual Component * Copy () const
 Copy component (deep copy).
 
virtual void Read (std::istream &is, bool binary)
 
virtual void Write (std::ostream &os, bool binary) const
 Write component to stream.
 
- Public Member Functions inherited from Component
 Component ()
 
virtual int32 Index () const
 Returns the index in the sequence of layers in the neural net; intended only to be used in debugging information.
 
virtual void SetIndex (int32 index)
 
void Propagate (const ChunkInfo &in_info, const ChunkInfo &out_info, const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out) const
 A non-virtual propagate function that first resizes output if necessary.
 
virtual ~Component ()
 

Private Member Functions

 KALDI_DISALLOW_COPY_AND_ASSIGN (SpliceComponent)
 

Private Attributes

int32 input_dim_
 
std::vector< int32 > context_
 
int32 const_component_dim_
 

Additional Inherited Members

- Static Public Member Functions inherited from Component
static Component * ReadNew (std::istream &is, bool binary)
 Read component from stream.
 
static Component * NewFromString (const std::string &initializer_line)
 Initialize the Component from one line whose first token is the type.
 
static Component * NewComponentOfType (const std::string &type)
 Return a new Component of the given type.
 

Detailed Description

Splices a context window of frames together [over time].
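
As a rough illustration of what "splicing a context window" means, the following is a minimal CPU sketch on plain std::vector data. It is not Kaldi code: the `Frame` alias and the `SpliceFrames` helper are hypothetical, the sketch ignores const_component_dim and chunking, and it assumes that only frames whose whole context window lies inside the input produce an output row.

```cpp
#include <cassert>
#include <vector>

using Frame = std::vector<float>;  // hypothetical alias: one feature vector

// Hypothetical helper, not part of Kaldi.  `context` must be sorted with
// front() <= 0 <= back().  Each output row is the concatenation of the input
// frames at t + c for every offset c in `context`.
std::vector<Frame> SpliceFrames(const std::vector<Frame> &in,
                                const std::vector<int> &context) {
  assert(!context.empty() && context.front() <= 0 && context.back() >= 0);
  const int num_in = static_cast<int>(in.size());
  const int left = -context.front(), right = context.back();
  std::vector<Frame> out;
  // Only centre frames t with a complete window [t - left, t + right].
  for (int t = left; t + right < num_in; ++t) {
    Frame row;
    for (int c : context)
      row.insert(row.end(), in[t + c].begin(), in[t + c].end());
    out.push_back(row);
  }
  return out;
}
```

With five one-dimensional input frames and the context (-1, 0, 1), this sketch yields three output rows of dimension three.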

Definition at line 1092 of file nnet-component.h.

Constructor & Destructor Documentation

◆ SpliceComponent()

SpliceComponent ( )
inline

Definition at line 1094 of file nnet-component.h.

1094 { } // called only prior to Read() or Init().

Member Function Documentation

◆ Backprop()

void Backprop ( const ChunkInfo &  in_info,
const ChunkInfo &  out_info,
const CuMatrixBase< BaseFloat > &  in_value,
const CuMatrixBase< BaseFloat > &  out_value,
const CuMatrixBase< BaseFloat > &  out_deriv,
Component *  to_update,
CuMatrix< BaseFloat > *  in_deriv 
) const
virtual

Perform backward pass propagation of the derivative, and also either update the model (if to_update == this) or update another model or compute the model derivative (otherwise).

Note: in_value and out_value are the values of the input and output of the component. They may be dummy variables if BackpropNeedsInput() or BackpropNeedsOutput(), respectively, returns false for that component (not all components need these).

The number of chunks (here in_info.NumChunks(), which must equal out_info.NumChunks()) lets us treat the input matrix as contiguous-in-time chunks of equal size; it only matters if splicing is involved.

Implements Component.

Definition at line 2662 of file nnet-component.cc.

References ChunkInfo::Check(), ChunkInfo::CheckSize(), ChunkInfo::ChunkSize(), CuMatrixBase< Real >::CopyRows(), ChunkInfo::GetIndex(), ChunkInfo::GetOffset(), AffineComponent::InputDim(), KALDI_ASSERT, kaldi::kUndefined, ChunkInfo::NumChunks(), ChunkInfo::NumCols(), CuMatrixBase< Real >::NumCols(), ChunkInfo::NumRows(), CuMatrixBase< Real >::NumRows(), AffineComponent::OutputDim(), and CuMatrix< Real >::Resize().

2668  {
2669  in_info.Check();
2670  out_info.Check();
2671  out_info.CheckSize(out_deriv);
2672  in_deriv->Resize(in_info.NumRows(), in_info.NumCols(), kUndefined);
2673  KALDI_ASSERT(in_info.NumChunks() == out_info.NumChunks());
2674  int32 num_chunks = in_info.NumChunks();
2675  // rewrite backpropagate
2676 
2677  int32 out_chunk_size = out_info.ChunkSize(),
2678  in_chunk_size = in_info.ChunkSize(),
2679  output_dim = out_deriv.NumCols(),
2680  input_dim = InputDim();
2681 
2682  KALDI_ASSERT(OutputDim() == output_dim);
2683 
2684  int32 num_splice = context_.size(),
2685  const_dim = const_component_dim_;
2686  // 'indexes' is, for each index from 0 to num_splice - 1,
2687  // then for each row of "in_deriv", the corresponding row of "out_deriv" that
2688  // we add, or -1 if no row of "out_deriv" contributes.
2689 
2690  std::vector<std::vector<int32> > indexes(num_splice);
2691  // If const_dim != 0, "const_indexes" determines, for each row of "in_deriv",
2692  // which row of "out_deriv" the last const_dim columns are copied from (this
2693  // part is not subject to splicing; it's assumed constant for each frame of input).
2694  std::vector<int32> const_indexes(const_dim == 0 ? 0 : in_deriv->NumRows(), -1);
2695 
2696  for (int32 c = 0; c < indexes.size(); c++)
2697  indexes[c].resize(in_deriv->NumRows(), -1); // set to -1 by default,
2698  // this gets interpreted by the CopyRows() code
2699  // as a signal to zero the output...
2700 
2701  int32 dim = input_dim - const_dim; // dimension we are splicing
2702  for (int32 chunk = 0; chunk < num_chunks; chunk++) {
2703  if (chunk == 0) { // this branch can be taken for all chunks, but is not
2704  // taken for efficiency reasons
2705  for (int32 c = 0; c < num_splice; c++) {
2706  for (int32 out_index = 0; out_index < out_chunk_size; out_index++) {
2707  int32 out_offset = out_info.GetOffset(out_index);
2708  int32 in_index = in_info.GetIndex(out_offset + context_[c]);
2709  indexes[c][chunk * in_chunk_size + in_index] =
2710  chunk * out_chunk_size + out_index;
2711  }
2712  }
2713  } else { // just copy the indexes from the previous chunk
2714  for (int32 c = 0; c < num_splice; c++) {
2715  for (int32 in_index = 0; in_index < in_chunk_size; in_index++) {
2716  int32 last_value = indexes[c][(chunk-1) * in_chunk_size + in_index];
2717  indexes[c][chunk * in_chunk_size + in_index] =
2718  (last_value == -1 ? -1 : last_value + out_chunk_size);
2719  }
2720  }
2721  }
2722  // this code corresponds to the way the forward propagation works; see
2723  // comments there.
2724  if (const_dim != 0) {
2725  for (int32 out_index = 0; out_index < out_chunk_size; out_index++) {
2726  const_indexes[chunk * in_chunk_size + out_index] =
2727  chunk * out_chunk_size + out_index;
2728  }
2729  }
2730  }
2731 
2732  CuMatrix<BaseFloat> temp_mat(in_deriv->NumRows(), dim, kUndefined);
2733 
2734  for (int32 c = 0; c < num_splice; c++) {
2735  CuArray<int32> cu_indexes(indexes[c]);
2736  int32 dim = input_dim - const_dim; // dimension we
2737  // are splicing
2738  CuSubMatrix<BaseFloat> out_deriv_part(out_deriv, 0, out_deriv.NumRows(),
2739  c * dim, dim),
2740  in_deriv_part(*in_deriv, 0, in_deriv->NumRows(),
2741  0, dim);
2742  if (c == 0) {
2743  in_deriv_part.CopyRows(out_deriv_part, cu_indexes);
2744  } else {
2745  temp_mat.CopyRows(out_deriv_part, cu_indexes);
2746  in_deriv_part.AddMat(1.0, temp_mat);
2747  }
2748  }
2749  if (const_dim != 0) {
2750  CuSubMatrix<BaseFloat> out_deriv_part(out_deriv, 0, out_deriv.NumRows(),
2751  out_deriv.NumCols() - const_dim,
2752  const_dim),
2753  in_deriv_part(*in_deriv, 0, in_deriv->NumRows(),
2754  in_deriv->NumCols() - const_dim, const_dim);
2755  CuArray<int32> cu_const_indexes(const_indexes);
2756  in_deriv_part.CopyRows(out_deriv_part, cu_const_indexes);
2757  }
2758 }
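
The listing above amounts to a scatter-add: every column block c of out_deriv is routed back to the input row it was copied from in the forward pass, and contributions from different context offsets are summed. A minimal CPU sketch of that idea follows. It is not Kaldi code: `Matrix` and `SpliceBackprop` are hypothetical names, the sketch assumes a single chunk in which output row r is centred on input row r + left (with left = -context.front()), and it ignores const_component_dim.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>;  // hypothetical row-major matrix

// Hypothetical helper, not part of Kaldi.  `dim` is the width of one spliced
// block, so out_deriv is expected to have context.size() * dim columns.
Matrix SpliceBackprop(const Matrix &out_deriv, const std::vector<int> &context,
                      int num_in_rows, int dim) {
  assert(!context.empty() && context.front() <= 0 && context.back() >= 0);
  const int left = -context.front();
  Matrix in_deriv(num_in_rows, std::vector<float>(dim, 0.0f));
  for (std::size_t c = 0; c < context.size(); ++c) {
    for (std::size_t r = 0; r < out_deriv.size(); ++r) {
      // Input row that fed block c of output row r in the forward pass.
      const int in_row = static_cast<int>(r) + left + context[c];
      assert(in_row >= 0 && in_row < num_in_rows);
      for (int d = 0; d < dim; ++d)
        in_deriv[in_row][d] += out_deriv[r][c * dim + d];
    }
  }
  return in_deriv;
}
```

Input rows that feed several output rows receive the sum of the matching derivative blocks, which is what the CopyRows() plus AddMat() pair in the listing accumulates.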

◆ BackpropNeedsInput()

virtual bool BackpropNeedsInput ( ) const
inline virtual

Reimplemented from Component.

Definition at line 1119 of file nnet-component.h.

1119 { return false; }

◆ BackpropNeedsOutput()

virtual bool BackpropNeedsOutput ( ) const
inline virtual

Reimplemented from Component.

Definition at line 1120 of file nnet-component.h.

References kaldi::cu::Copy(), and KALDI_DISALLOW_COPY_AND_ASSIGN.

1120 { return false; }

◆ Context()

virtual std::vector<int32> Context ( ) const
inline virtual

Return a vector describing the temporal context this component requires for each frame of output, as a sorted list.

The default implementation returns a vector ( 0 ), but a splicing layer might return e.g. (-2, -1, 0, 1, 2); the list doesn't have to be contiguous. Note: the context needed by the entire network is a function of the contexts needed by all the components. It is required that Context().front() <= 0 and Context().back() >= 0.

Reimplemented from Component.

Definition at line 1106 of file nnet-component.h.

References Component::Propagate().

1106 { return context_; }
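
The description says the network-level context is a function of the per-component contexts without spelling the function out. One plausible reading, sketched below, is that for components stacked one after another the left and right extents simply add up. This is an assumption made for illustration, and `TotalContext` is a hypothetical helper, not a Kaldi function.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper, not part of Kaldi.  Each inner vector plays the role
// of one component's Context().  Returns {left, right}: how many frames
// before and after the current one the whole stack would need, ASSUMING the
// extents of sequentially stacked components are additive.
std::pair<int, int> TotalContext(
    const std::vector<std::vector<int>> &contexts) {
  int left = 0, right = 0;
  for (const std::vector<int> &ctx : contexts) {
    assert(!ctx.empty() && ctx.front() <= 0 && ctx.back() >= 0);
    left += -ctx.front();  // frames needed before the current frame
    right += ctx.back();   // frames needed after the current frame
  }
  return {left, right};
}
```

Under that assumption, a splice over (-2, ..., 2), a context-free layer (0), and a second splice over (-3, 0, 3) would together need 5 frames on each side.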

◆ Copy()

Component * Copy ( ) const
virtual

Copy component (deep copy).

Implements Component.

Definition at line 2760 of file nnet-component.cc.

References SpliceComponent::const_component_dim_, SpliceComponent::context_, and SpliceComponent::input_dim_.

2760  {
2761  SpliceComponent *ans = new SpliceComponent();
2762  ans->input_dim_ = input_dim_;
2763  ans->context_ = context_;
2764  ans->const_component_dim_ = const_component_dim_;
2765  return ans;
2766 }

◆ Info()

std::string Info ( ) const
virtual

Reimplemented from Component.

Definition at line 2463 of file nnet-component.cc.

References Component::Info().

2463  {
2464  std::stringstream stream;
2465  std::ostringstream os;
2466  std::copy(context_.begin(), context_.end(),
2467  std::ostream_iterator<int32>(os, " "));
2468  stream << Component::Info() << ", context=" << os.str();
2469  if (const_component_dim_ != 0)
2470  stream << ", const_component_dim=" << const_component_dim_;
2471 
2472  return stream.str();
2473 }

◆ Init()

void Init ( int32  input_dim,
std::vector< int32 >  context,
int32  const_component_dim = 0 
)

Definition at line 2475 of file nnet-component.cc.

References kaldi::IsSortedAndUniq(), and KALDI_ASSERT.

Referenced by kaldi::ConvertSpliceComponent(), kaldi::nnet2::GenRandomNnet(), and kaldi::nnet2::UnitTestSpliceComponent().

2476  {
2477  input_dim_ = input_dim;
2478  const_component_dim_ = const_component_dim;
2479  context_ = context;
2480  KALDI_ASSERT(context_.size() > 0);
2481  KALDI_ASSERT(input_dim_ > 0 && context_.front() <= 0 && context_.back() >= 0);
2482  KALDI_ASSERT(IsSortedAndUniq(context));
2484 }

◆ InitFromString()

void InitFromString ( std::string  args)
virtual

Initialize, typically from a line of a config file.

The "args" will contain any parameters that need to be passed to the Component, e.g. dimensions.

Implements Component.

Definition at line 2488 of file nnet-component.cc.

References rnnlm::i, AffineComponentPreconditionedOnline::Init(), KALDI_ASSERT, KALDI_ERR, kaldi::nnet2::ParseFromString(), and AffineComponentPreconditionedOnline::Type().

2488  {
2489  std::string orig_args(args);
2490  int32 input_dim, left_context, right_context;
2491  std::vector <int32> context;
2492  bool in_dim_ok = ParseFromString("input-dim", &args, &input_dim);
2493  bool context_ok = ParseFromString("context", &args, &context);
2494  bool left_right_context_ok = ParseFromString("left-context", &args,
2495  &left_context) &&
2496  ParseFromString("right-context", &args,
2497  &right_context);
2498  int32 const_component_dim = 0;
2499  ParseFromString("const-component-dim", &args, &const_component_dim);
2500 
2501  if (!(in_dim_ok && (context_ok || left_right_context_ok)) ||
2502  !args.empty() || input_dim <= 0)
2503  KALDI_ERR << "Invalid initializer for layer of type "
2504  << Type() << ": \"" << orig_args << "\"";
2505  if (left_right_context_ok) {
2506  KALDI_ASSERT(context.size() == 0);
2507  for (int32 i = -left_context; i <= right_context; i++)
2508  context.push_back(i);
2509  }
2510  Init(input_dim, context, const_component_dim);
2511 }
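
The two ways of specifying the context in the listing above (an explicit "context" list, or a "left-context"/"right-context" pair that is expanded into a contiguous range) can be sketched in isolation. The helpers below are hypothetical and only mirror the expansion loop and the assertions seen in Init(); the strictly-increasing test stands in for IsSortedAndUniq(), and no config-string parsing is attempted.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper mirroring the loop in InitFromString():
// left=2, right=2 becomes (-2, -1, 0, 1, 2).
std::vector<int> ExpandContext(int left_context, int right_context) {
  std::vector<int> context;
  for (int i = -left_context; i <= right_context; ++i)
    context.push_back(i);
  return context;
}

// Hypothetical stand-in for the context checks in Init(): non-empty,
// front() <= 0 <= back(), and strictly increasing (sorted and unique).
bool IsValidContext(const std::vector<int> &context) {
  if (context.empty()) return false;
  if (context.front() > 0 || context.back() < 0) return false;
  return std::adjacent_find(context.begin(), context.end(),
                            [](int a, int b) { return a >= b; }) ==
         context.end();
}
```

A non-contiguous list such as (-3, 0, 3) passes these checks, while a list with duplicates or one that does not straddle zero does not.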

◆ InputDim()

virtual int32 InputDim ( ) const
inline virtual

Get size of input vectors.

Implements Component.

Definition at line 1104 of file nnet-component.h.

1104 { return input_dim_; }

◆ KALDI_DISALLOW_COPY_AND_ASSIGN()

KALDI_DISALLOW_COPY_AND_ASSIGN ( SpliceComponent  )
private

◆ OutputDim()

int32 OutputDim ( ) const
virtual

Get size of output vectors.

Implements Component.

Definition at line 2513 of file nnet-component.cc.

2513  {
2514  return (input_dim_ - const_component_dim_)
2515  * (context_.size())
2516  + const_component_dim_;
2517 }
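
The output layout used by Propagate() (one block of width input_dim - const_dim per context offset, followed by the constant part in the last const_dim columns) gives a simple closed form for the output dimension. The free function below is a hypothetical restatement of that arithmetic, not the member function itself; the bound const_dim < input_dim in its assertion is an assumption.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical free function, not part of Kaldi: the spliced part is
// repeated once per context offset, the constant part is appended once.
int SplicedOutputDim(int input_dim, std::size_t num_context, int const_dim) {
  // Assumed preconditions for this sketch.
  assert(input_dim > 0 && const_dim >= 0 && const_dim < input_dim);
  return (input_dim - const_dim) * static_cast<int>(num_context) + const_dim;
}
```

For example, 40-dimensional features spliced over 5 offsets give 200 dimensions; 140-dimensional input whose last 100 dimensions are constant gives 40 * 5 + 100 = 300.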

◆ Propagate()

void Propagate ( const ChunkInfo &  in_info,
const ChunkInfo &  out_info,
const CuMatrixBase< BaseFloat > &  in,
CuMatrixBase< BaseFloat > *  out 
) const
virtual

Perform forward pass propagation Input->Output.

Each row of "in" is one frame or training example. The rows are interpreted as "num_chunks" equally sized chunks of frames; this only matters for layers that do things like context splicing. Typically the number of chunks will either be 1 (when we're processing a single contiguous chunk of data) or will equal the number of rows of "in", but other values are possible if some layers do splicing.

Implements Component.

Definition at line 2577 of file nnet-component.cc.

References ChunkInfo::Check(), ChunkInfo::CheckSize(), ChunkInfo::ChunkSize(), ChunkInfo::GetIndex(), ChunkInfo::GetOffset(), KALDI_ASSERT, KALDI_ERR, ChunkInfo::NumChunks(), ChunkInfo::NumCols(), CuMatrixBase< Real >::NumCols(), and CuMatrixBase< Real >::NumRows().

2580  {
2581 
2582  // Check the inputs are correct and resize output
2583  in_info.Check();
2584  out_info.Check();
2585  in_info.CheckSize(in);
2586  out_info.CheckSize(*out);
2587  KALDI_ASSERT(in_info.NumChunks() == out_info.NumChunks());
2588 
2589  int32 in_chunk_size = in_info.ChunkSize(),
2590  out_chunk_size = out_info.ChunkSize(),
2591  input_dim = in_info.NumCols();
2592 
2593  if (out_chunk_size <= 0)
2594  KALDI_ERR << "Splicing features: output will have zero dimension. "
2595  << "Probably a code error.";
2596 
2597  // 'indexes' is, for each index from 0 to context_.size() - 1,
2598  // then for each row of "out", the corresponding row of "in" that we copy from
2599  int32 num_splice = context_.size();
2600  std::vector<std::vector<int32> > indexes(num_splice);
2601  for (int32 c = 0; c < num_splice; c++)
2602  indexes[c].resize(out->NumRows());
2603  // If const_component_dim_ != 0, "const_indexes" will be used to determine which
2604  // row of "in" we copy the last part of each row of "out" from (this part is
2605  // not subject to splicing; it's assumed constant for each frame of "input").
2606  int32 const_dim = const_component_dim_;
2607  std::vector<int32> const_indexes(const_dim == 0 ? 0 : out->NumRows());
2608 
2609  for (int32 chunk = 0; chunk < in_info.NumChunks(); chunk++) {
2610  if (chunk == 0) {
2611  // this branch could be used for all chunks in the matrix,
2612  // but is restricted to chunk 0 for efficiency reasons
2613  for (int32 c = 0; c < num_splice; c++) {
2614  for (int32 out_index = 0; out_index < out_chunk_size; out_index++) {
2615  int32 out_offset = out_info.GetOffset(out_index);
2616  int32 in_index = in_info.GetIndex(out_offset + context_[c]);
2617  indexes[c][chunk * out_chunk_size + out_index] =
2618  chunk * in_chunk_size + in_index;
2619  }
2620  }
2621  } else { // just copy the indices from the previous chunk
2622  // and offset these by input chunk size
2623  for (int32 c = 0; c < num_splice; c++) {
2624  for (int32 out_index = 0; out_index < out_chunk_size; out_index++) {
2625  int32 last_value = indexes[c][(chunk-1) * out_chunk_size + out_index];
2626  indexes[c][chunk * out_chunk_size + out_index] =
2627  (last_value == -1 ? -1 : last_value + in_chunk_size);
2628  }
2629  }
2630  }
2631  if (const_dim != 0) {
2632  for (int32 out_index = 0; out_index < out_chunk_size; out_index++)
2633  const_indexes[chunk * out_chunk_size + out_index] =
2634  chunk * in_chunk_size + out_index; // there is
2635  // an arbitrariness here; since we assume the const_component
2636  // is constant within a chunk, it doesn't matter from where we copy.
2637  }
2638  }
2639 
2640 
2641  for (int32 c = 0; c < num_splice; c++) {
2642  int32 dim = input_dim - const_dim; // dimension we
2643  // are splicing
2644  CuSubMatrix<BaseFloat> in_part(in, 0, in.NumRows(),
2645  0, dim),
2646  out_part(*out, 0, out->NumRows(),
2647  c * dim, dim);
2648  CuArray<int32> cu_indexes(indexes[c]);
2649  out_part.CopyRows(in_part, cu_indexes);
2650  }
2651  if (const_dim != 0) {
2652  CuSubMatrix<BaseFloat> in_part(in, 0, in.NumRows(),
2653  in.NumCols() - const_dim, const_dim),
2654  out_part(*out, 0, out->NumRows(),
2655  out->NumCols() - const_dim, const_dim);
2656 
2657  CuArray<int32> cu_const_indexes(const_indexes);
2658  out_part.CopyRows(in_part, cu_const_indexes);
2659  }
2660 }
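
The core of the listing above is the per-offset index table: for each output row it records which input row to copy from, and later chunks reuse the chunk-0 pattern shifted by the input chunk size. The sketch below rebuilds that table for one context offset under a simplifying assumption: every chunk holds contiguous frames, with inputs starting at time offset `first_in` and outputs at `first_out`. These two parameters and the `BuildIndexes` name are hypothetical stand-ins for what ChunkInfo::GetOffset() and ChunkInfo::GetIndex() provide.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper, not part of Kaldi.  Returns, for every output row
// (over all chunks), the input row that one context offset copies from.
std::vector<int> BuildIndexes(int context_offset, int num_chunks,
                              int in_chunk_size, int out_chunk_size,
                              int first_in, int first_out) {
  std::vector<int> indexes(num_chunks * out_chunk_size);
  for (int chunk = 0; chunk < num_chunks; ++chunk) {
    for (int out_index = 0; out_index < out_chunk_size; ++out_index) {
      // Time offset of this output row within its chunk.
      const int out_offset = first_out + out_index;
      // Row within the input chunk holding time out_offset + context_offset.
      const int in_index = out_offset + context_offset - first_in;
      assert(in_index >= 0 && in_index < in_chunk_size);
      indexes[chunk * out_chunk_size + out_index] =
          chunk * in_chunk_size + in_index;
    }
  }
  return indexes;
}
```

Unlike the listing, this sketch recomputes the pattern for every chunk instead of copying it from the previous chunk; the resulting table is the same under the stated assumption.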

◆ Read()

void Read ( std::istream &  is,
bool  binary 
)
virtual

Implements Component.

Definition at line 2768 of file nnet-component.cc.

References kaldi::nnet2::ExpectOneOrTwoTokens(), kaldi::ExpectToken(), rnnlm::i, KALDI_ERR, kaldi::ReadBasicType(), kaldi::ReadIntegerVector(), and kaldi::ReadToken().

2768  {
2769  ExpectOneOrTwoTokens(is, binary, "<SpliceComponent>", "<InputDim>");
2770  ReadBasicType(is, binary, &input_dim_);
2771  std::string token;
2772  ReadToken(is, false, &token);
2773  if (token == "<LeftContext>") {
2774  int32 left_context=0, right_context=0;
2775  std::vector<int32> context;
2776  ReadBasicType(is, binary, &left_context);
2777  ExpectToken(is, binary, "<RightContext>");
2778  ReadBasicType(is, binary, &right_context);
2779  for (int32 i = -1 * left_context; i <= right_context; i++)
2780  context.push_back(i);
2781  context_ = context;
2782  } else if (token == "<Context>") {
2783  ReadIntegerVector(is, binary, &context_);
2784  } else {
2785  KALDI_ERR << "Unknown token " << token
2786  << ", the model might be corrupted";
2787  }
2788  ExpectToken(is, binary, "<ConstComponentDim>");
2789  ReadBasicType(is, binary, &const_component_dim_);
2790  ExpectToken(is, binary, "</SpliceComponent>");
2791 }

◆ Type()

virtual std::string Type ( ) const
inline virtual

Implements Component.

Definition at line 1101 of file nnet-component.h.

1101 { return "SpliceComponent"; }

◆ Write()

void Write ( std::ostream &  os,
bool  binary 
) const
virtual

Write component to stream.

Implements Component.

Definition at line 2793 of file nnet-component.cc.

References kaldi::WriteBasicType(), kaldi::WriteIntegerVector(), and kaldi::WriteToken().

2793  {
2794  WriteToken(os, binary, "<SpliceComponent>");
2795  WriteToken(os, binary, "<InputDim>");
2796  WriteBasicType(os, binary, input_dim_);
2797  WriteToken(os, binary, "<Context>");
2798  WriteIntegerVector(os, binary, context_);
2799  WriteToken(os, binary, "<ConstComponentDim>");
2800  WriteBasicType(os, binary, const_component_dim_);
2801  WriteToken(os, binary, "</SpliceComponent>");
2802 }
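
The token order written above can be made concrete with a small text-only sketch. It does not use Kaldi's I/O functions: `WriteSpliceText` is a hypothetical helper, and the "[ ... ]" layout for the integer vector and the single-space separators are assumptions about the text format, shown only to make the sequence of fields visible.

```cpp
#include <cassert>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical text-only serializer, not part of Kaldi.  It follows the
// token order of Write(); the bracketed vector layout is an assumption.
std::string WriteSpliceText(int input_dim, const std::vector<int> &context,
                            int const_dim) {
  std::ostringstream os;
  os << "<SpliceComponent> <InputDim> " << input_dim << " <Context> [ ";
  for (int c : context) os << c << ' ';
  os << "] <ConstComponentDim> " << const_dim << " </SpliceComponent>";
  return os.str();
}
```

Note that Write() always emits the <Context> form, whereas Read() also accepts the <LeftContext>/<RightContext> pair and expands it to a contiguous range.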

Member Data Documentation

◆ const_component_dim_

int32 const_component_dim_
private

Definition at line 1128 of file nnet-component.h.

Referenced by SpliceComponent::Copy().

◆ context_

std::vector<int32> context_
private

Definition at line 1127 of file nnet-component.h.

Referenced by SpliceComponent::Copy().

◆ input_dim_

int32 input_dim_
private

Definition at line 1126 of file nnet-component.h.

Referenced by SpliceComponent::Copy().


The documentation for this class was generated from the following files: