MultiBasisComponent Class Reference

#include <nnet-multibasis-component.h>


Public Member Functions

 MultiBasisComponent (int32 dim_in, int32 dim_out)
 
 ~MultiBasisComponent ()
 
Component * Copy () const
 Copy component (deep copy).
 
ComponentType GetType () const
 Get the type identification of the component.
 
void InitData (std::istream &is)
 Initialize the content of the component from the 'line' of the prototype.
 
void ReadData (std::istream &is, bool binary)
 Reads the component content.
 
void WriteData (std::ostream &os, bool binary) const
 Writes the component content.
 
Nnet & GetBasis (int32 id)
 
const Nnet & GetBasis (int32 id) const
 
int32 NumParams () const
 Number of trainable parameters.
 
void GetGradient (VectorBase< BaseFloat > *gradient) const
 Get gradient reshaped as a vector.
 
void GetParams (VectorBase< BaseFloat > *params) const
 Get the trainable parameters reshaped as a vector.
 
void SetParams (const VectorBase< BaseFloat > &params)
 Set the trainable parameters from a vector (same layout as GetParams).
 
std::string Info () const
 Print some additional info (after <ComponentName> and the dims).
 
std::string InfoGradient () const
 Print some additional info about the gradient (after <...> and the dims).
 
std::string InfoPropagate () const
 
std::string InfoBackPropagate () const
 
void PropagateFnc (const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out)
 Forward pass transformation (part of the abstract propagation/backpropagation interface).
 
void BackpropagateFnc (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrixBase< BaseFloat > *in_diff)
 Backward pass transformation (to be implemented by the derived class).
 
void Update (const CuMatrixBase< BaseFloat > &input, const CuMatrixBase< BaseFloat > &diff)
 Compute gradient and update parameters.
 
void SetTrainOptions (const NnetTrainOptions &opts)
 Overrides the default, UpdatableComponent::SetTrainOptions(...).
 
void SetLearnRateCoef (BaseFloat val)
 Overrides the default, UpdatableComponent::SetLearnRateCoef(...).
 
void SetBiasLearnRateCoef (BaseFloat val)
 Overrides the default, UpdatableComponent::SetBiasLearnRateCoef(...).
 
- Public Member Functions inherited from UpdatableComponent
 UpdatableComponent (int32 input_dim, int32 output_dim)
 
virtual ~UpdatableComponent ()
 
bool IsUpdatable () const
 Check if the component contains trainable parameters.
 
const NnetTrainOptions & GetTrainOptions () const
 Get the training options from the component.
 
- Public Member Functions inherited from Component
 Component (int32 input_dim, int32 output_dim)
 Generic interface of a component.
 
virtual ~Component ()
 
virtual bool IsMultistream () const
 Check if component has 'Recurrent' interface (trainable and recurrent).
 
int32 InputDim () const
 Get the dimension of the input.
 
int32 OutputDim () const
 Get the dimension of the output.
 
void Propagate (const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
 Perform forward-pass propagation 'in' -> 'out'.
 
void Backpropagate (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
 Perform backward-pass propagation 'out_diff' -> 'in_diff'.
 
void Write (std::ostream &os, bool binary) const
 Write the component to a stream.
 

Private Attributes

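std::vector< Nnet > nnet_basis_
 The vector of 'basis' networks (the basis outputs are combined according to posterior_ from selector_).
 
std::vector< CuMatrix< BaseFloat > > basis_out_
 Per-basis output buffers, filled in PropagateFnc and reused in BackpropagateFnc.
 
Nnet selector_
 Selector network.
 
BaseFloat selector_lr_coef_
 Scale applied to the error signal backpropagated into selector_ (default 1.0).
 
CuMatrix< BaseFloat > posterior_
 The output of 'selector_', stored transposed (one row per basis, one column per frame).
 
Vector< BaseFloat > posterior_sum_
 Per-basis occupancy: posterior_ summed over the frames of the current batch.
 
BaseFloat threshold_
 Threshold applied to posterior_sum_; bases at or below it are disabled (default 0.1).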
 

Additional Inherited Members

- Public Types inherited from Component
enum  ComponentType {
  kUnknown = 0x0, kUpdatableComponent = 0x0100, kAffineTransform, kLinearTransform,
  kConvolutionalComponent, kLstmProjected, kBlstmProjected, kRecurrentComponent,
  kActivationFunction = 0x0200, kSoftmax, kHiddenSoftmax, kBlockSoftmax,
  kSigmoid, kTanh, kParametricRelu, kDropout,
  kLengthNormComponent, kTranform = 0x0400, kRbm, kSplice,
  kCopy, kTranspose, kBlockLinearity, kAddShift,
  kRescale, kKlHmm = 0x0800, kSentenceAveragingComponent, kSimpleSentenceAveragingComponent,
  kAveragePoolingComponent, kMaxPoolingComponent, kFramePoolingComponent, kParallelComponent,
  kMultiBasisComponent
}
 Component type identification mechanism.
 
- Static Public Member Functions inherited from Component
static const char * TypeToMarker (ComponentType t)
 Converts component type to marker.
 
static ComponentType MarkerToType (const std::string &s)
 Converts marker to component type (case insensitive).
 
static Component * Init (const std::string &conf_line)
 Initialize component from a line in config file.
 
static Component * Read (std::istream &is, bool binary)
 Read the component from a stream (static method).
 
- Static Public Attributes inherited from Component
static const struct key_value kMarkerMap []
 The table with pairs of Component types and markers (defined in nnet-component.cc).
 
- Protected Attributes inherited from UpdatableComponent
NnetTrainOptions opts_
 Option-class with training hyper-parameters.
 
BaseFloat learn_rate_coef_
 Scalar applied to learning rate for weight matrices (to be used in ::Update method).
 
BaseFloat bias_learn_rate_coef_
 Scalar applied to learning rate for bias (to be used in ::Update method).
 
- Protected Attributes inherited from Component
int32 input_dim_
 Dimension of the input of the Component.
 
int32 output_dim_
 Dimension of the output of the Component.
 

Detailed Description

Definition at line 34 of file nnet-multibasis-component.h.

Constructor & Destructor Documentation

◆ MultiBasisComponent()

MultiBasisComponent (int32 dim_in, int32 dim_out)
inline

Definition at line 36 of file nnet-multibasis-component.h.

Referenced by MultiBasisComponent::Copy().

36  :
37  UpdatableComponent(dim_in, dim_out),
38  selector_lr_coef_(1.0),
39  threshold_(0.1)
40  { }

◆ ~MultiBasisComponent()

~MultiBasisComponent ( )
inline

Definition at line 42 of file nnet-multibasis-component.h.

43  { }

Member Function Documentation

◆ BackpropagateFnc()

void BackpropagateFnc (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrixBase< BaseFloat > *in_diff)
inline virtual

Backward pass transformation (to be implemented by the derived class).

Implements Component.

Definition at line 338 of file nnet-multibasis-component.h.

References CuMatrixBase< Real >::AddDiagVecMat(), CuMatrixBase< Real >::AddMat(), Nnet::Backpropagate(), MultiBasisComponent::basis_out_, CuMatrixBase< Real >::ColRange(), CuMatrixBase< Real >::CopyFromMat(), Nnet::InputDim(), kaldi::kNoTrans, kaldi::kTrans, MultiBasisComponent::nnet_basis_, CuMatrixBase< Real >::NumRows(), Component::OutputDim(), MultiBasisComponent::posterior_, MultiBasisComponent::posterior_sum_, CuMatrixBase< Real >::Row(), CuMatrixBase< Real >::Scale(), MultiBasisComponent::selector_, MultiBasisComponent::selector_lr_coef_, MultiBasisComponent::threshold_, and CuMatrix< Real >::Transpose().

341  {
342  // dimensions,
343  int32 num_basis = nnet_basis_.size(),
344  num_frames = in.NumRows();
345 
346  // split the in_diff,
347  CuSubMatrix<BaseFloat> in_diff_basis(
348  in_diff->ColRange(0, nnet_basis_[0].InputDim())
349  );
350  CuSubMatrix<BaseFloat> in_diff_selector(
351  in_diff->ColRange(nnet_basis_[0].InputDim(), selector_.InputDim())
352  );
353 
354  // backprop through 'selector',
355  CuMatrix<BaseFloat> selector_out_diff(num_basis, num_frames);
356  for (int32 i = 0; i < num_basis; i++) {
357  if (posterior_sum_(i) > threshold_) {
358  selector_out_diff.Row(i).AddDiagMatMat(1.0, out_diff, kNoTrans, basis_out_[i], kTrans, 0.0);
359  }
360  }
361  selector_out_diff.Transpose();
362  selector_out_diff.Scale(selector_lr_coef_);
363  CuMatrix<BaseFloat> in_diff_selector_tmp;
364  selector_.Backpropagate(selector_out_diff, &in_diff_selector_tmp);
365  in_diff_selector.CopyFromMat(in_diff_selector_tmp);
366 
367  // backprop through 'basis',
368  CuMatrix<BaseFloat> out_diff_scaled(num_frames, OutputDim()),
369  in_diff_basis_tmp;
370  for (int32 i = 0; i < num_basis; i++) {
371  // use only basis with occupancy >0.1,
372  if (posterior_sum_(i) > threshold_) {
373  out_diff_scaled.AddDiagVecMat(1.0, posterior_.Row(i), out_diff, kNoTrans, 0.0);
374  nnet_basis_[i].Backpropagate(out_diff_scaled, &in_diff_basis_tmp);
375  in_diff_basis.AddMat(1.0, in_diff_basis_tmp);
376  }
377  }
378  }
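The two error signals computed in BackpropagateFnc can be sketched in plain C++. This is a hedged illustration, not the component's code: Mat, SelectorDiff and BasisOutDiff are hypothetical names, std::vector rows stand in for CuMatrix, and the per-basis occupancy is assumed to be given. The selector receives, per frame and basis, the dot product of out_diff with that basis output (the diagonal that AddDiagMatMat extracts), scaled by the selector learning-rate coefficient; each basis receives out_diff with every frame scaled by its posterior (the AddDiagVecMat step).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for CuMatrix<BaseFloat>: a vector of rows.
using Mat = std::vector<std::vector<float>>;

// Gradient w.r.t. the selector output, laid out frames x bases:
//   diff[t][i] = lr_coef * <out_diff[t], basis_out[i][t]>
// Bases whose occupancy is not above `threshold` keep a zero gradient.
Mat SelectorDiff(const Mat& out_diff, const std::vector<Mat>& basis_out,
                 const std::vector<float>& occupancy, float threshold,
                 float lr_coef) {
  const std::size_t num_frames = out_diff.size();
  const std::size_t num_basis = basis_out.size();
  Mat diff(num_frames, std::vector<float>(num_basis, 0.0f));
  for (std::size_t i = 0; i < num_basis; i++) {
    if (occupancy[i] <= threshold) continue;  // disabled basis
    for (std::size_t t = 0; t < num_frames; t++) {
      float dot = 0.0f;
      for (std::size_t d = 0; d < out_diff[t].size(); d++)
        dot += out_diff[t][d] * basis_out[i][t][d];
      diff[t][i] = lr_coef * dot;
    }
  }
  return diff;
}

// Error signal fed into one basis network: frame t of out_diff is
// multiplied by that basis's posterior at frame t.
Mat BasisOutDiff(const Mat& out_diff, const std::vector<float>& posterior_i) {
  Mat scaled = out_diff;
  for (std::size_t t = 0; t < scaled.size(); t++)
    for (float& v : scaled[t]) v *= posterior_i[t];
  return scaled;
}
```

The real code builds the selector gradient as bases x frames and then calls Transpose(); the sketch fills the frames x bases layout directly, which gives the same matrix.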

◆ Copy()

Component * Copy () const
inline virtual

Copy component (deep copy),.

Implements Component.

Definition at line 45 of file nnet-multibasis-component.h.

References MultiBasisComponent::MultiBasisComponent().

45 { return new MultiBasisComponent(*this); }

◆ GetBasis() [1/2]

Nnet& GetBasis ( int32  id)
inline

Definition at line 196 of file nnet-multibasis-component.h.

References MultiBasisComponent::nnet_basis_.

196 { return nnet_basis_.at(id); }

◆ GetBasis() [2/2]

const Nnet& GetBasis ( int32  id) const
inline

Definition at line 197 of file nnet-multibasis-component.h.

References MultiBasisComponent::nnet_basis_.

197 { return nnet_basis_.at(id); }

◆ GetGradient()

void GetGradient (VectorBase< BaseFloat > *gradient) const
inline virtual

Get gradient reshaped as a vector. Not implemented for this component: the body raises KALDI_ERR with "TODO, not yet implemented!".

Implements UpdatableComponent.

Definition at line 207 of file nnet-multibasis-component.h.

References KALDI_ERR.

207  {
208  KALDI_ERR << "TODO, not yet implemented!";
209  }

◆ GetParams()

void GetParams (VectorBase< BaseFloat > *params) const
inline virtual

Get the trainable parameters reshaped as a vector: the selector's parameters first, then those of each basis network in order.

Implements UpdatableComponent.

Definition at line 211 of file nnet-multibasis-component.h.

References VectorBase< Real >::Dim(), Nnet::GetParams(), KALDI_ASSERT, MultiBasisComponent::nnet_basis_, MultiBasisComponent::NumParams(), VectorBase< Real >::Range(), and MultiBasisComponent::selector_.

211  {
212  int32 offset = 0;
213  Vector<BaseFloat> params_tmp;
214  // selector,
215  selector_.GetParams(&params_tmp);
216  params->Range(offset, params_tmp.Dim()).CopyFromVec(params_tmp);
217  offset += params_tmp.Dim();
218  // basis,
219  for (int32 i = 0; i < nnet_basis_.size(); i++) {
220  nnet_basis_[i].GetParams(&params_tmp);
221  params->Range(offset, params_tmp.Dim()).CopyFromVec(params_tmp);
222  offset += params_tmp.Dim();
223  }
224  KALDI_ASSERT(offset == NumParams());
225  }
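GetParams writes the selector block first and then one block per basis network, advancing offset until it equals NumParams(). A minimal sketch of that layout, assuming plain std::vector<float> in place of Vector<BaseFloat> and per-network parameter vectors as inputs (ConcatParams is a hypothetical name):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Builds the supervector [ selector | basis 0 | basis 1 | ... ].
// Each input vector is assumed to be one network's flattened parameters.
std::vector<float> ConcatParams(const std::vector<float>& selector,
                                const std::vector<std::vector<float>>& basis) {
  std::vector<float> params(selector);  // selector block comes first
  for (const auto& b : basis)
    params.insert(params.end(), b.begin(), b.end());  // then each basis
  return params;
}
```

For a selector with 2 parameters and bases with 1 and 2 parameters, the result has 5 entries: indices 0 to 1 hold the selector, index 2 holds basis 0, and indices 3 to 4 hold basis 1.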

◆ GetType()

ComponentType GetType () const
inline virtual

Get the type identification of the component.

Implements Component.

Definition at line 46 of file nnet-multibasis-component.h.

References Component::kMultiBasisComponent.

◆ Info()

std::string Info () const
inline virtual

Print some additional info (after <ComponentName> and the dims): the Info() of every basis network, followed by that of the selector.

Reimplemented from Component.

Definition at line 240 of file nnet-multibasis-component.h.

References Nnet::Info(), MultiBasisComponent::nnet_basis_, and MultiBasisComponent::selector_.

240  {
241  std::ostringstream os;
242  for (int32 i = 0; i < nnet_basis_.size(); i++) {
243  os << "basis_network #" << i+1 << " {\n"
244  << nnet_basis_[i].Info()
245  << "}\n";
246  }
247  os << "\nselector {\n"
248  << selector_.Info()
249  << "}";
250  return os.str();
251  }

◆ InfoBackPropagate()

std::string InfoBackPropagate ( ) const
inline

Definition at line 283 of file nnet-multibasis-component.h.

References Nnet::InfoBackPropagate(), MultiBasisComponent::nnet_basis_, MultiBasisComponent::posterior_sum_, MultiBasisComponent::selector_, and MultiBasisComponent::threshold_.

283  {
284  std::ostringstream os;
285  for (int32 i = 0; i < nnet_basis_.size(); i++) {
286  if (posterior_sum_(i) > threshold_) {
287  os << "basis_backpropagate #" << i+1 << "{\n"
288  << nnet_basis_[i].InfoBackPropagate(false)
289  << "}\n";
290  }
291  }
292  os << "selector_backpropagate {\n"
293  << selector_.InfoBackPropagate(false)
294  << "}\n";
295  return os.str();
296  }

◆ InfoGradient()

std::string InfoGradient () const
inline virtual

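Print some additional info about the gradient (after <...> and the dims). Only the bases whose occupancy posterior_sum_ exceeds threshold_ are reported, followed by the selector.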

Reimplemented from Component.

Definition at line 253 of file nnet-multibasis-component.h.

References Nnet::InfoGradient(), MultiBasisComponent::nnet_basis_, MultiBasisComponent::posterior_sum_, MultiBasisComponent::selector_, and MultiBasisComponent::threshold_.

253  {
254  std::ostringstream os;
255  for (int32 i = 0; i < nnet_basis_.size(); i++) {
256  if (posterior_sum_(i) > threshold_) {
257  os << "basis_gradient #" << i+1 << " {\n"
258  << nnet_basis_[i].InfoGradient(false)
259  << "}\n";
260  }
261  }
262  os << "selector_gradient {\n"
263  << selector_.InfoGradient(false)
264  << "}";
265  return os.str();
266  }

◆ InfoPropagate()

std::string InfoPropagate ( ) const
inline

Definition at line 268 of file nnet-multibasis-component.h.

References Nnet::InfoPropagate(), MultiBasisComponent::nnet_basis_, MultiBasisComponent::posterior_sum_, MultiBasisComponent::selector_, and MultiBasisComponent::threshold_.

268  {
269  std::ostringstream os;
270  for (int32 i = 0; i < nnet_basis_.size(); i++) {
271  if (posterior_sum_(i) > threshold_) {
272  os << "basis_propagate #" << i+1 << " {\n"
273  << nnet_basis_[i].InfoPropagate(false)
274  << "}\n";
275  }
276  }
277  os << "selector_propagate {\n"
278  << selector_.InfoPropagate(false)
279  << "}\n";
280  return os.str();
281  }

◆ InitData()

void InitData (std::istream &is)
inline virtual

Initialize the content of the component from the 'line' of the prototype.

Implements UpdatableComponent.

Definition at line 48 of file nnet-multibasis-component.h.

References Nnet::Init(), Nnet::InputDim(), Component::InputDim(), KALDI_ASSERT, KALDI_ERR, KALDI_LOG, MultiBasisComponent::nnet_basis_, Nnet::OutputDim(), Component::OutputDim(), Nnet::Read(), kaldi::ReadBasicType(), kaldi::ReadToken(), MultiBasisComponent::selector_, MultiBasisComponent::selector_lr_coef_, AffineTransform::SetLinearity(), and MatrixBase< Real >::SetUnit().

48  {
49  // define options,
50  std::string selector_proto;
51  std::string selector_filename;
52  std::string basis_proto;
53  std::string basis_filename;
54  std::vector<std::string> basis_filename_vector;
55 
56  // parse config
57  std::string token;
58  while (is >> std::ws, !is.eof()) {
59  ReadToken(is, false, &token);
60  if (token == "<SelectorProto>") ReadToken(is, false, &selector_proto);
61  else if (token == "<SelectorFilename>") ReadToken(is, false, &selector_filename);
62  else if (token == "<SelectorLearnRateCoef>") ReadBasicType(is, false, &selector_lr_coef_);
63  else if (token == "<BasisProto>") ReadToken(is, false, &basis_proto);
64  else if (token == "<BasisFilename>") ReadToken(is, false, &basis_filename);
65  else if (token == "<BasisFilenameVector>") {
66  while(is >> std::ws, !is.eof()) {
67  std::string file_or_end;
68  ReadToken(is, false, &file_or_end);
69  if (file_or_end == "</BasisFilenameVector>") break;
70  basis_filename_vector.push_back(file_or_end);
71  }
72  } else KALDI_ERR << "Unknown token " << token << ", typo in config?"
73  << " (SelectorProto|SelectorFilename|BasisProto|BasisFilename|BasisFilenameVector)";
74  }
75 
77 
78  // selector,
79  if (selector_proto != "") {
80  KALDI_LOG << "Initializing 'selector' from : " << selector_proto;
81  selector_.Init(selector_proto);
82  }
83  if (selector_filename != "") {
84  KALDI_LOG << "Reading 'selector' from : " << selector_filename;
85  selector_.Read(selector_filename);
86  }
87 
88  // as many empty basis as outputs of the selector,
90  // fill the basis,
91  if (basis_proto != "") {
92  // Initialized from prototype,
93  KALDI_LOG << "Initializing 'basis' from : " << basis_proto;
94  for (int32 i = 0; i < nnet_basis_.size(); i++) {
95  nnet_basis_[i].Init(basis_proto);
96  }
97  } else if (basis_filename != "") {
98  // Load 1 initial basis repeateadly,
99  KALDI_LOG << "Reading 'basis' from : " << basis_filename;
100  for (int32 i = 0; i < nnet_basis_.size(); i++) {
101  nnet_basis_[i].Read(basis_filename);
102  }
103  } else if (basis_filename_vector.size() > 0) {
104  // Read a list of basis functions,
105  if (basis_filename_vector.size() != nnet_basis_.size()) {
106  KALDI_ERR << "We need " << nnet_basis_.size() << " filenames. "
107  << "We got " << basis_filename_vector.size();
108  }
109  for (int32 i = 0; i < nnet_basis_.size(); i++) {
110  KALDI_LOG << "Reading 'basis' from : "
111  << basis_filename_vector[i];
112  nnet_basis_[i].Read(basis_filename_vector[i]);
113  }
114  } else {
115  // Initialize basis by square identity matrix,
116  int32 basis_input_dim = InputDim() - selector_.InputDim();
117  KALDI_LOG << "Initializing 'basis' to Identity <AffineTransform> "
118  << OutputDim() << "x" << basis_input_dim;
119  KALDI_ASSERT(OutputDim() == basis_input_dim); // has to be square!
120  Matrix<BaseFloat> m(OutputDim(), basis_input_dim);
121  m.SetUnit();
122  // wrap identity into AffineTransform,
123  // (bias is vector of zeros),
124  AffineTransform identity_comp(basis_input_dim, OutputDim());
125  identity_comp.SetLinearity(CuMatrix<BaseFloat>(m));
126  //
127  for (int32 i = 0; i < nnet_basis_.size(); i++) {
128  nnet_basis_[i].AppendComponent(identity_comp);
129  }
130  }
131 
132  // check,
133  KALDI_ASSERT(InputDim() == selector_.InputDim() + nnet_basis_[0].InputDim());
135  }
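The token loop at the start of InitData can be sketched with standard streams. This is a hedged illustration that assumes whitespace-delimited std::string extraction behaves like ReadToken/ReadBasicType in text mode; MultiBasisOpts and ParseMultiBasisLine are hypothetical names, a std::runtime_error stands in for KALDI_ERR, and the file names in the usage example are made up.

```cpp
#include <cassert>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical container for the options InitData recognizes.
struct MultiBasisOpts {
  std::string selector_proto, selector_filename, basis_proto, basis_filename;
  float selector_lr_coef = 1.0f;  // same default as the constructor
  std::vector<std::string> basis_filename_vector;
};

// Reads "<Token> value" pairs in any order; <BasisFilenameVector> collects
// file names until the closing </BasisFilenameVector> token.
MultiBasisOpts ParseMultiBasisLine(const std::string& line) {
  std::istringstream is(line);
  MultiBasisOpts o;
  std::string token;
  while (is >> token) {
    if (token == "<SelectorProto>") is >> o.selector_proto;
    else if (token == "<SelectorFilename>") is >> o.selector_filename;
    else if (token == "<SelectorLearnRateCoef>") is >> o.selector_lr_coef;
    else if (token == "<BasisProto>") is >> o.basis_proto;
    else if (token == "<BasisFilename>") is >> o.basis_filename;
    else if (token == "<BasisFilenameVector>") {
      std::string file_or_end;
      while (is >> file_or_end && file_or_end != "</BasisFilenameVector>")
        o.basis_filename_vector.push_back(file_or_end);
    } else {
      throw std::runtime_error("Unknown token " + token + ", typo in config?");
    }
  }
  return o;
}
```

After parsing, the listing's if/else chain gives <BasisProto> precedence over <BasisFilename>, which in turn wins over <BasisFilenameVector>; with none of them, every basis becomes an identity <AffineTransform>, which requires the basis input and output dimensions to match. For the selector the two branches are independent, so a <SelectorFilename> read replaces a network initialized from <SelectorProto>.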

◆ NumParams()

int32 NumParams () const
inline virtual

Number of trainable parameters: those of the selector plus those of every basis network.

Implements UpdatableComponent.

Definition at line 199 of file nnet-multibasis-component.h.

References MultiBasisComponent::nnet_basis_, Nnet::NumParams(), and MultiBasisComponent::selector_.

Referenced by MultiBasisComponent::GetParams(), and MultiBasisComponent::SetParams().

199  {
200  int32 num_params_sum = selector_.NumParams();
201  for (int32 i = 0; i < nnet_basis_.size(); i++) {
202  num_params_sum += nnet_basis_[i].NumParams();
203  }
204  return num_params_sum;
205  }

◆ PropagateFnc()

void PropagateFnc (const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out)
inline virtual

Forward pass transformation (part of the abstract propagation/backpropagation interface, to be implemented by the derived class).

Implements Component.

Definition at line 298 of file nnet-multibasis-component.h.

References CuVectorBase< Real >::AddColSumMat(), CuMatrixBase< Real >::AddDiagVecMat(), kaldi::ApproxEqual(), MultiBasisComponent::basis_out_, CuMatrixBase< Real >::ColRange(), Nnet::InputDim(), KALDI_ASSERT, kaldi::kNoTrans, MultiBasisComponent::nnet_basis_, MultiBasisComponent::posterior_, MultiBasisComponent::posterior_sum_, Nnet::Propagate(), MultiBasisComponent::selector_, and MultiBasisComponent::threshold_.

299  {
300  // dimensions,
301  int32 num_basis = nnet_basis_.size();
302 
303  // make sure we have all the buffers,
304  if (basis_out_.size() != num_basis) {
305  basis_out_.resize(num_basis);
306  }
307 
308  // split the input,
309  const CuSubMatrix<BaseFloat> in_basis(
310  in.ColRange(0, nnet_basis_[0].InputDim())
311  );
312  const CuSubMatrix<BaseFloat> in_selector(
313  in.ColRange(nnet_basis_[0].InputDim(), selector_.InputDim())
314  );
315 
316  // get the 'selector_' posteriors,
317  selector_.Propagate(in_selector, &posterior_);
318  KALDI_ASSERT(posterior_.Row(0).Min() >= 0.0);
319  KALDI_ASSERT(posterior_.Row(0).Max() <= 1.0);
320  KALDI_ASSERT(ApproxEqual(posterior_.Row(0).Sum(), 1.0));
321  posterior_.Transpose(); // trans,
322 
323  // sum 'selector_' posteriors over time,
324  CuVector<BaseFloat> posterior_sum(num_basis);
325  posterior_sum.AddColSumMat(1.0, posterior_, 0.0);
326  posterior_sum_ = Vector<BaseFloat>(posterior_sum);
327 
328  // combine the 'basis' outputs,
329  for (int32 i = 0; i < nnet_basis_.size(); i++) {
330  if (posterior_sum_(i) > threshold_) {
331  // use only basis with occupancy >0.1,
332  nnet_basis_[i].Propagate(in_basis, &basis_out_[i]);
333  out->AddDiagVecMat(1.0, posterior_.Row(i), basis_out_[i], kNoTrans);
334  }
335  }
336  }
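The combination performed by PropagateFnc can be sketched without Kaldi types. This is a minimal, hedged illustration, not the component's code: Mat and CombineBasis are hypothetical names and std::vector rows stand in for CuMatrix. It assumes posterior is already transposed (one row per basis, one column per frame), as posterior_ is after the Transpose() call, and that every basis output is supplied; the real component only propagates the bases that pass the threshold test.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for CuMatrix<BaseFloat>: a vector of rows.
using Mat = std::vector<std::vector<float>>;

// Posterior-weighted sum of basis outputs:
//   posterior: num_basis x num_frames (transposed selector output)
//   basis_out: for each basis, a num_frames x dim output matrix
// A basis whose occupancy (its posterior summed over all frames) is not
// above `threshold` is skipped, like the posterior_sum_ > threshold_ test.
Mat CombineBasis(const Mat& posterior, const std::vector<Mat>& basis_out,
                 float threshold) {
  assert(!posterior.empty() && posterior.size() == basis_out.size());
  const std::size_t num_basis = posterior.size();
  const std::size_t num_frames = posterior[0].size();
  const std::size_t dim = basis_out[0][0].size();
  Mat out(num_frames, std::vector<float>(dim, 0.0f));  // starts at zero
  for (std::size_t i = 0; i < num_basis; i++) {
    float occupancy = 0.0f;
    for (float p : posterior[i]) occupancy += p;
    if (occupancy <= threshold) continue;  // unused basis contributes nothing
    for (std::size_t t = 0; t < num_frames; t++)
      for (std::size_t d = 0; d < dim; d++)
        out[t][d] += posterior[i][t] * basis_out[i][t][d];  // scale frame t
  }
  return out;
}
```

Because the occupancy is a sum over all frames of the batch rather than an average, the default threshold_ of 0.1 only drops a basis whose total posterior mass over the whole batch is at most 0.1.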

◆ ReadData()

void ReadData (std::istream &is, bool binary)
inline virtual

Reads the component content.

Reimplemented from Component.

Definition at line 137 of file nnet-multibasis-component.h.

References kaldi::ExpectToken(), Nnet::InputDim(), Component::InputDim(), KALDI_ASSERT, KALDI_ERR, MultiBasisComponent::nnet_basis_, Nnet::OutputDim(), Component::OutputDim(), kaldi::Peek(), kaldi::PeekToken(), Nnet::Read(), kaldi::ReadBasicType(), kaldi::ReadToken(), MultiBasisComponent::selector_, and MultiBasisComponent::selector_lr_coef_.

137  {
138  // Read all the '<Tokens>' in arbitrary order,
139  bool end_loop = false;
140  while (!end_loop && '<' == Peek(is, binary)) {
141  std::string token;
142  int first_char = PeekToken(is, binary);
143  switch (first_char) {
144  case 'S': ReadToken(is, false, &token);
145  if (token == "<SelectorLearnRateCoef>") ReadBasicType(is, binary, &selector_lr_coef_);
146  else if (token == "<Selector>") selector_.Read(is, binary);
147  else KALDI_ERR << "Unknown token: " << token;
148  break;
149  case 'N': ExpectToken(is, binary, "<NumBasis>");
150  int32 num_basis;
151  ReadBasicType(is, binary, &num_basis);
152  nnet_basis_.resize(num_basis);
153  for (int32 i = 0; i < num_basis; i++) {
154  int32 dummy;
155  ExpectToken(is, binary, "<Basis>");
156  ReadBasicType(is, binary, &dummy);
157  nnet_basis_[i].Read(is, binary);
158  }
159  break;
160  case '!':
161  ExpectToken(is, binary, "<!EndOfComponent>");
162  end_loop=true;
163  break;
164  default:
165  ReadToken(is, false, &token);
166  KALDI_ERR << "Unknown token: " << token;
167  }
168  }
169 
170  // check,
172  KALDI_ASSERT(InputDim() == selector_.InputDim() + nnet_basis_[0].InputDim());
174  }
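ReadData dispatches on the character that follows '<' in the next token (which is what kaldi::PeekToken returns), so the tokens can appear in any order. A hedged sketch of just that decision, applied to a complete token string; TokenKind and ClassifyToken are hypothetical names, and the real code goes on to read and compare the full token inside each branch.

```cpp
#include <cassert>
#include <string>

// Hypothetical labels for the branches of ReadData's switch statement.
enum class TokenKind { kSelector, kNumBasis, kEnd, kUnknown };

// Classifies a token by its first character after '<'.
TokenKind ClassifyToken(const std::string& token) {
  if (token.size() < 2 || token[0] != '<') return TokenKind::kUnknown;
  switch (token[1]) {
    case 'S': return TokenKind::kSelector;  // <SelectorLearnRateCoef>, <Selector>
    case 'N': return TokenKind::kNumBasis;  // <NumBasis>, then <Basis> blocks
    case '!': return TokenKind::kEnd;       // <!EndOfComponent>
    default:  return TokenKind::kUnknown;   // ReadData raises KALDI_ERR here
  }
}
```

In the 'N' branch the integer after each <Basis> token is read into a dummy variable and discarded, so the bases are restored purely by position.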

◆ SetBiasLearnRateCoef()

void SetBiasLearnRateCoef (BaseFloat val)
inline virtual

Overrides the default, UpdatableComponent::SetBiasLearnRateCoef(...): sets val on every updatable component of every basis network; the selector is left unchanged.

Reimplemented from UpdatableComponent.

Definition at line 419 of file nnet-multibasis-component.h.

References MultiBasisComponent::nnet_basis_ and UpdatableComponent::SetBiasLearnRateCoef().

419  {
420  // loop over nnets,
421  for (int32 i = 0; i < nnet_basis_.size(); i++) {
422  // loop over components,
423  for (int32 j = 0; j < nnet_basis_[i].NumComponents(); j++) {
424  if (nnet_basis_[i].GetComponent(j).IsUpdatable()) {
425  UpdatableComponent& comp =
426  dynamic_cast<UpdatableComponent&>(nnet_basis_[i].GetComponent(j));
427  // set the value,
428  comp.SetBiasLearnRateCoef(val);
429  }
430  }
431  }
432  }

◆ SetLearnRateCoef()

void SetLearnRateCoef (BaseFloat val)
inline virtual

Overrides the default, UpdatableComponent::SetLearnRateCoef(...): sets val on every updatable component of every basis network; the selector is left unchanged.

Reimplemented from UpdatableComponent.

Definition at line 400 of file nnet-multibasis-component.h.

References MultiBasisComponent::nnet_basis_ and UpdatableComponent::SetLearnRateCoef().

400  {
401  // loop over nnets,
402  for (int32 i = 0; i < nnet_basis_.size(); i++) {
403  // loop over components,
404  for (int32 j = 0; j < nnet_basis_[i].NumComponents(); j++) {
405  if (nnet_basis_[i].GetComponent(j).IsUpdatable()) {
406  UpdatableComponent& comp =
407  dynamic_cast<UpdatableComponent&>(nnet_basis_[i].GetComponent(j));
408  // set the value,
409  comp.SetLearnRateCoef(val);
410  }
411  }
412  }
413  }
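SetLearnRateCoef (and SetBiasLearnRateCoef) walks every component of every basis network and down-casts only those that report IsUpdatable(). The pattern can be sketched with a toy class hierarchy; Component, UpdatableComponent, Net and SetBasisLearnRateCoef below are hypothetical stand-ins written for illustration, not Kaldi's classes.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Toy stand-ins (hypothetical) for a component hierarchy.
struct Component {
  virtual ~Component() = default;
  virtual bool IsUpdatable() const { return false; }
};
struct UpdatableComponent : Component {
  float learn_rate_coef = 1.0f;
  bool IsUpdatable() const override { return true; }
  void SetLearnRateCoef(float val) { learn_rate_coef = val; }
};
using Net = std::vector<std::unique_ptr<Component>>;  // toy "network"

// Same loop structure as the listing: nets outer, components inner,
// dynamic_cast only after IsUpdatable() has confirmed the type.
void SetBasisLearnRateCoef(std::vector<Net>& nnet_basis, float val) {
  for (Net& net : nnet_basis)
    for (auto& comp : net)
      if (comp->IsUpdatable())
        dynamic_cast<UpdatableComponent&>(*comp).SetLearnRateCoef(val);
}
```

Only the basis networks are visited; the selector's effective rate is instead controlled by selector_lr_coef_, which scales the error signal passed into selector_ in BackpropagateFnc.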

◆ SetParams()

void SetParams (const VectorBase< BaseFloat > &params)
inline virtual

Set the trainable parameters from a vector with the same layout as GetParams (selector first, then each basis network in order).

Implements UpdatableComponent.

Definition at line 227 of file nnet-multibasis-component.h.

References KALDI_ASSERT, MultiBasisComponent::nnet_basis_, Nnet::NumParams(), MultiBasisComponent::NumParams(), VectorBase< Real >::Range(), MultiBasisComponent::selector_, and Nnet::SetParams().

227  {
228    int32 offset = 0;
229    // selector,
230    selector_.SetParams(params.Range(offset, selector_.NumParams()));
231    offset += selector_.NumParams();
232    // basis,
233    for (int32 i = 0; i < nnet_basis_.size(); i++) {
234      nnet_basis_[i].SetParams(params.Range(offset, nnet_basis_[i].NumParams()));
235      offset += nnet_basis_[i].NumParams();
236    }
237    KALDI_ASSERT(offset == NumParams());
238  }
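The listing walks an offset through the flat vector, handing one contiguous slice to the selector and then one to each basis network. A minimal sketch of that layout and its inverse follows, assuming each network simply owns a std::vector<float>; ToyNet, NumParamsAll, GetParamsAll and SetParamsAll are hypothetical names, not Kaldi API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for an Nnet: owns a flat parameter buffer.
struct ToyNet {
  std::vector<float> params;
  std::size_t NumParams() const { return params.size(); }
};

// Total count = selector parameters + parameters of every basis network.
std::size_t NumParamsAll(const ToyNet& selector, const std::vector<ToyNet>& basis) {
  std::size_t n = selector.NumParams();
  for (const ToyNet& b : basis) n += b.NumParams();
  return n;
}

// Flatten: selector first, then basis 0, 1, ... (the order SetParamsAll reads).
std::vector<float> GetParamsAll(const ToyNet& selector, const std::vector<ToyNet>& basis) {
  std::vector<float> out(selector.params);
  for (const ToyNet& b : basis) out.insert(out.end(), b.params.begin(), b.params.end());
  return out;
}

// Un-flatten: advance an offset through 'flat', copying one slice per network.
void SetParamsAll(const std::vector<float>& flat, ToyNet* selector, std::vector<ToyNet>* basis) {
  assert(flat.size() == NumParamsAll(*selector, *basis));
  std::size_t offset = 0;
  auto take = [&](ToyNet* net) {
    const std::size_t n = net->NumParams();
    auto first = flat.begin() + static_cast<std::ptrdiff_t>(offset);
    net->params.assign(first, first + static_cast<std::ptrdiff_t>(n));
    offset += n;
  };
  take(selector);
  for (ToyNet& b : *basis) take(&b);
  // Every element of 'flat' must have been consumed.
  assert(offset == flat.size());
}
```

Because GetParamsAll and SetParamsAll traverse the networks in the same order, a vector produced by one can be fed back to the other unchanged.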

◆ SetTrainOptions()

void SetTrainOptions (const NnetTrainOptions &opts)
inline virtual

Overrides the default UpdatableComponent::SetTrainOptions(...) by passing opts on to every basis network through Nnet::SetTrainOptions().

Reimplemented from UpdatableComponent.

Definition at line 389 of file nnet-multibasis-component.h.

References MultiBasisComponent::nnet_basis_, MultiBasisComponent::selector_, and Nnet::SetTrainOptions().

389  {
391    for (int32 i=0; i<nnet_basis_.size(); i++) {
392      nnet_basis_[i].SetTrainOptions(opts);
393    }
394  }

◆ Update()

void Update (const CuMatrixBase< BaseFloat > &input, const CuMatrixBase< BaseFloat > &diff)
inline virtual

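Compute gradient and update parameters. In MultiBasisComponent the body is empty, so calling Update() on this component changes no parameters.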

Implements UpdatableComponent.

Definition at line 380 of file nnet-multibasis-component.h.

381  {
382    { } // do nothing
383  }

◆ WriteData()

void WriteData (std::ostream &os, bool binary) const
inline virtual

Writes the component content.

Reimplemented from Component.

Definition at line 176 of file nnet-multibasis-component.h.

References MultiBasisComponent::nnet_basis_, MultiBasisComponent::selector_, MultiBasisComponent::selector_lr_coef_, Nnet::Write(), kaldi::WriteBasicType(), and kaldi::WriteToken().

176  {
177    int32 num_basis = nnet_basis_.size();
178    WriteToken(os, binary, "<SelectorLearnRateCoef>");
179    WriteBasicType(os, binary, selector_lr_coef_);
180    if (!binary) os << "\n\n";
181    WriteToken(os, binary, "<Selector>");
182    if (!binary) os << "\n";
183    selector_.Write(os, binary);
184    if (!binary) os << "\n";
185    WriteToken(os, binary, "<NumBasis>");
186    WriteBasicType(os, binary, num_basis);
187    if (!binary) os << "\n";
188    for (int32 i = 0; i < num_basis; i++) {
189      WriteToken(os, binary, "<Basis>");
190      WriteBasicType(os, binary, i+1);
191      if (!binary) os << "\n";
192      nnet_basis_.at(i).Write(os, binary);
193    }
194  }
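The order of tokens and newlines written in text mode can be sketched as follows. This is a hedged illustration, not Kaldi code: WriteTokenText and WriteBasicTypeText are simplified, hypothetical stand-ins for kaldi::WriteToken() and kaldi::WriteBasicType() that assume each item is followed by a single space, and write_net is a hypothetical callback standing in for Nnet::Write(). Only the token order (selector coefficient, selector, basis count, then 1-based basis entries) is taken from the listing above.

```cpp
#include <cassert>
#include <functional>
#include <sstream>
#include <string>

// Simplified stand-ins (assumption): each item is written followed by one space.
// Binary mode is not modelled.
void WriteTokenText(std::ostream& os, const std::string& token) { os << token << ' '; }

template <typename T>
void WriteBasicTypeText(std::ostream& os, T t) { os << t << ' '; }

// Mirrors the text-mode order of MultiBasisComponent::WriteData().
// 'write_net' is a hypothetical callback; index 0 stands for the selector,
// indices 1..num_basis for the basis networks.
void WriteMultiBasisText(std::ostream& os, float selector_lr_coef, int num_basis,
                         const std::function<void(std::ostream&, int)>& write_net) {
  assert(num_basis >= 0);
  WriteTokenText(os, "<SelectorLearnRateCoef>");
  WriteBasicTypeText(os, selector_lr_coef);
  os << "\n\n";
  WriteTokenText(os, "<Selector>");
  os << "\n";
  write_net(os, 0);
  os << "\n";
  WriteTokenText(os, "<NumBasis>");
  WriteBasicTypeText(os, num_basis);
  os << "\n";
  for (int i = 0; i < num_basis; i++) {
    WriteTokenText(os, "<Basis>");
    WriteBasicTypeText(os, i + 1);  // basis indices are written 1-based
    os << "\n";
    write_net(os, i + 1);
  }
}
```

A reader for this format (the role of ReadData()) would be expected to consume the same tokens in the same order.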

Member Data Documentation

◆ basis_out_

std::vector<CuMatrix<BaseFloat> > basis_out_
private

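◆ nnet_basis_

std::vector< Nnet > nnet_basis_

The vector of 'basis' networks; their outputs are combined according to posterior_, the output of selector_.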

◆ posterior_

CuMatrix<BaseFloat> posterior_
private

The output of selector_, according to which the outputs of the basis networks are combined.

Definition at line 445 of file nnet-multibasis-component.h.

Referenced by MultiBasisComponent::BackpropagateFnc(), and MultiBasisComponent::PropagateFnc().
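To illustrate how posterior_ can weight the basis outputs, here is a minimal sketch that assumes the combination is a per-frame weighted sum, out(r, c) = sum over b of posterior(r, b) * basis_out_b(r, c), with one selector output column per basis network. This rule is an assumption: the PropagateFnc() body is not part of this section, and the roles of threshold_ and posterior_sum_ are not modelled. The Matrix alias and CombineBasis are hypothetical stand-ins for CuMatrix code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>;  // row-major, one row per frame

// Assumed combination rule: each output row is the posterior-weighted sum
// of the corresponding rows of all basis outputs.
Matrix CombineBasis(const Matrix& posterior, const std::vector<Matrix>& basis_out) {
  assert(!basis_out.empty());
  const std::size_t rows = posterior.size();
  const std::size_t cols = basis_out[0].empty() ? 0 : basis_out[0][0].size();
  Matrix out(rows, std::vector<float>(cols, 0.0f));
  for (std::size_t b = 0; b < basis_out.size(); b++) {
    assert(basis_out[b].size() == rows);
    for (std::size_t r = 0; r < rows; r++) {
      assert(posterior[r].size() == basis_out.size());  // one column per basis
      for (std::size_t c = 0; c < cols; c++) {
        out[r][c] += posterior[r][b] * basis_out[b][r][c];
      }
    }
  }
  return out;
}
```

With a one-hot posterior row, the output row equals the selected basis output; soft posteriors interpolate between bases.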

◆ posterior_sum_

◆ selector_

◆ selector_lr_coef_

◆ threshold_


The documentation for this class was generated from the following file: