ParallelComponent Class Reference

#include <nnet-parallel-component.h>


Public Member Functions

 ParallelComponent (int32 dim_in, int32 dim_out)
 
 ~ParallelComponent ()
 
Component * Copy () const
 Copy component (deep copy).
 
ComponentType GetType () const
 Get type identification of the component.
 
const Nnet & GetNestedNnet (int32 id) const
 
Nnet & GetNestedNnet (int32 id)
 
void InitData (std::istream &is)
 Initialize the content of the component from the 'line' of the prototype.
 
void ReadData (std::istream &is, bool binary)
 Reads the component content.
 
void WriteData (std::ostream &os, bool binary) const
 Writes the component content.
 
int32 NumParams () const
 Number of trainable parameters.
 
void GetGradient (VectorBase< BaseFloat > *gradient) const
 Get the gradient reshaped as a vector.
 
void GetParams (VectorBase< BaseFloat > *params) const
 Get the trainable parameters reshaped as a vector.
 
void SetParams (const VectorBase< BaseFloat > &params)
 Set the trainable parameters from a vector (same layout as GetParams()).
 
std::string Info () const
 Print some additional info (after <ComponentName> and the dims).
 
std::string InfoGradient () const
 Print some additional info about the gradient (after <...> and the dims).
 
std::string InfoPropagate () const
 
std::string InfoBackPropagate () const
 
void PropagateFnc (const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out)
 Abstract interface for propagation/backpropagation.
 
void BackpropagateFnc (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrixBase< BaseFloat > *in_diff)
 Backward pass transformation (to be implemented by the derived class).
 
void Update (const CuMatrixBase< BaseFloat > &input, const CuMatrixBase< BaseFloat > &diff)
 Compute gradient and update parameters.
 
void SetTrainOptions (const NnetTrainOptions &opts)
 Overrides the default, UpdatableComponent::SetTrainOptions(...).
 
void SetLearnRateCoef (BaseFloat val)
 Overrides the default, UpdatableComponent::SetLearnRateCoef(...).
 
void SetBiasLearnRateCoef (BaseFloat val)
 Overrides the default, UpdatableComponent::SetBiasLearnRateCoef(...).
 
void SetSeqLengths (const std::vector< int32 > &sequence_lengths)
 Overrides the default, MultistreamComponent::SetSeqLengths(...).
 
- Public Member Functions inherited from MultistreamComponent
 MultistreamComponent (int32 input_dim, int32 output_dim)
 
bool IsMultistream () const
 Check if the component has the 'Recurrent' interface (trainable and recurrent).
 
int32 NumStreams () const
 
virtual void ResetStreams (const std::vector< int32 > &stream_reset_flag)
 Optional function to reset the transfer of context (not used for BLSTMs).
 
- Public Member Functions inherited from UpdatableComponent
 UpdatableComponent (int32 input_dim, int32 output_dim)
 
virtual ~UpdatableComponent ()
 
bool IsUpdatable () const
 Check if the component contains trainable parameters.
 
const NnetTrainOptions & GetTrainOptions () const
 Get the training options from the component.
 
- Public Member Functions inherited from Component
 Component (int32 input_dim, int32 output_dim)
 Generic interface of a component.
 
virtual ~Component ()
 
int32 InputDim () const
 Get the dimension of the input.
 
int32 OutputDim () const
 Get the dimension of the output.
 
void Propagate (const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
 Perform forward-pass propagation 'in' -> 'out'.
 
void Backpropagate (const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
 Perform backward-pass propagation 'out_diff' -> 'in_diff'.
 
void Write (std::ostream &os, bool binary) const
 Write the component to a stream.
 

Private Attributes

std::vector< Nnet > nnet_
 

Additional Inherited Members

- Public Types inherited from Component
enum  ComponentType {
  kUnknown = 0x0, kUpdatableComponent = 0x0100, kAffineTransform, kLinearTransform,
  kConvolutionalComponent, kLstmProjected, kBlstmProjected, kRecurrentComponent,
  kActivationFunction = 0x0200, kSoftmax, kHiddenSoftmax, kBlockSoftmax,
  kSigmoid, kTanh, kParametricRelu, kDropout,
  kLengthNormComponent, kTranform = 0x0400, kRbm, kSplice,
  kCopy, kTranspose, kBlockLinearity, kAddShift,
  kRescale, kKlHmm = 0x0800, kSentenceAveragingComponent, kSimpleSentenceAveragingComponent,
  kAveragePoolingComponent, kMaxPoolingComponent, kFramePoolingComponent, kParallelComponent,
  kMultiBasisComponent
}
 Component type identification mechanism.
 
- Static Public Member Functions inherited from Component
static const char * TypeToMarker (ComponentType t)
 Converts a component type to its marker.
 
static ComponentType MarkerToType (const std::string &s)
 Converts a marker to its component type (case insensitive).
 
static Component * Init (const std::string &conf_line)
 Initialize a component from a line in a config file.
 
static Component * Read (std::istream &is, bool binary)
 Read the component from a stream (static method).
 
- Static Public Attributes inherited from Component
static const struct key_value kMarkerMap []
 The table with pairs of Component types and markers (defined in nnet-component.cc).
 
- Protected Attributes inherited from MultistreamComponent
std::vector< int32 > sequence_lengths_
 
- Protected Attributes inherited from UpdatableComponent
NnetTrainOptions opts_
 Option-class with training hyper-parameters.
 
BaseFloat learn_rate_coef_
 Scalar applied to the learning rate for weight matrices (used in the ::Update method).
 
BaseFloat bias_learn_rate_coef_
 Scalar applied to the learning rate for the bias (used in the ::Update method).
 
- Protected Attributes inherited from Component
int32 input_dim_
 Dimension of the input of the Component.
 
int32 output_dim_
 Dimension of the output of the Component.
 

Detailed Description

Definition at line 36 of file nnet-parallel-component.h.

Constructor & Destructor Documentation

◆ ParallelComponent()

ParallelComponent (int32 dim_in, int32 dim_out)
inline

Definition at line 38 of file nnet-parallel-component.h.

Referenced by ParallelComponent::Copy().

ParallelComponent(int32 dim_in, int32 dim_out):
  MultistreamComponent(dim_in, dim_out)
{ }

◆ ~ParallelComponent()

~ParallelComponent ( )
inline

Definition at line 42 of file nnet-parallel-component.h.

{ }

Member Function Documentation

◆ BackpropagateFnc()

void BackpropagateFnc (const CuMatrixBase< BaseFloat > &in,
                       const CuMatrixBase< BaseFloat > &out,
                       const CuMatrixBase< BaseFloat > &out_diff,
                       CuMatrixBase< BaseFloat > *in_diff)
inline virtual

Backward pass transformation (to be implemented by the derived class).

Implements Component.

Definition at line 264 of file nnet-parallel-component.h.

References CuMatrixBase< Real >::ColRange(), CuMatrixBase< Real >::CopyFromMat(), rnnlm::i, and ParallelComponent::nnet_.

{
  // column-offsets for data buffers 'in,out',
  int32 input_offset = 0, output_offset = 0;
  // loop over nnets,
  for (int32 i = 0; i < nnet_.size(); i++) {
    // get the data 'windows',
    CuSubMatrix<BaseFloat> src(
      out_diff.ColRange(output_offset, nnet_[i].OutputDim())
    );
    CuSubMatrix<BaseFloat> tgt(
      in_diff->ColRange(input_offset, nnet_[i].InputDim())
    );
    // ::Backpropagate through auxiliary matrix (CuMatrix in the interface),
    CuMatrix<BaseFloat> tgt_aux;
    nnet_[i].Backpropagate(src, &tgt_aux);
    tgt.CopyFromMat(tgt_aux);
    // advance the offsets,
    input_offset += nnet_[i].InputDim();
    output_offset += nnet_[i].OutputDim();
  }
}
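To make the column bookkeeping concrete, here is a minimal single-frame sketch in plain C++ that does not use Kaldi. As an assumption for illustration, each nested network is replaced by a hypothetical dense linear map `ToyLinearNet` (y = W x), so its input gradient is W^T times its slice of `out_diff`. Each slice of `out_diff` is read at the running output offset and written back at the running input offset, which is the reverse of the forward pass.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a nested Nnet: a dense linear map y = W x,
// with W stored row-major as out_dim x in_dim.
struct ToyLinearNet {
  std::size_t in_dim, out_dim;
  std::vector<double> w;  // size out_dim * in_dim
};

// Backward pass of a parallel block for one frame (one matrix row).
// Each nested net reads its slice of out_diff at the output offset and
// writes W^T * slice into its slice of in_diff at the input offset.
std::vector<double> ParallelBackprop(const std::vector<ToyLinearNet>& nets,
                                     const std::vector<double>& out_diff) {
  std::size_t in_total = 0;
  for (const auto& n : nets) in_total += n.in_dim;
  std::vector<double> in_diff(in_total, 0.0);
  std::size_t in_off = 0, out_off = 0;
  for (const auto& n : nets) {
    for (std::size_t c = 0; c < n.in_dim; ++c) {
      double acc = 0.0;
      for (std::size_t r = 0; r < n.out_dim; ++r)
        acc += n.w[r * n.in_dim + c] * out_diff[out_off + r];
      in_diff[in_off + c] = acc;
    }
    // advance the offsets, as in BackpropagateFnc()
    in_off += n.in_dim;
    out_off += n.out_dim;
  }
  return in_diff;
}
```

The real method processes a whole minibatch at once (one row per frame) via `ColRange()`; the per-row loop here only isolates the offset logic.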

◆ Copy()

Component* Copy ( ) const
inline virtual

Copy component (deep copy).

Implements Component.

Definition at line 45 of file nnet-parallel-component.h.

References ParallelComponent::ParallelComponent().

{ return new ParallelComponent(*this); }

◆ GetGradient()

void GetGradient ( VectorBase< BaseFloat > *  gradient) const
inline virtual

Get the gradient reshaped as a vector.

Implements UpdatableComponent.

Definition at line 158 of file nnet-parallel-component.h.

References VectorBase< Real >::Dim(), rnnlm::i, KALDI_ASSERT, ParallelComponent::nnet_, ParallelComponent::NumParams(), and VectorBase< Real >::Range().

{
  KALDI_ASSERT(gradient->Dim() == NumParams());
  int32 offset = 0;
  for (int32 i = 0; i < nnet_.size(); i++) {
    int32 n_params = nnet_[i].NumParams();
    Vector<BaseFloat> gradient_aux;  // we need 'Vector<>',
    nnet_[i].GetGradient(&gradient_aux);  // copy gradient from Nnet,
    gradient->Range(offset, n_params).CopyFromVec(gradient_aux);
    offset += n_params;
  }
  KALDI_ASSERT(offset == NumParams());
}
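The offset arithmetic shared by GetGradient() and GetParams() amounts to concatenating the per-network vectors. A minimal sketch over plain `std::vector<float>`; the function name `PackNested` is hypothetical, and each inner vector stands in for one nested network's gradient or parameters:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Concatenate per-nested-net vectors into one flat vector. Nested net i
// owns the contiguous range [offset_i, offset_i + n_params_i).
std::vector<float> PackNested(const std::vector<std::vector<float>>& per_net) {
  std::size_t total = 0;  // plays the role of NumParams()
  for (const auto& v : per_net) total += v.size();
  std::vector<float> flat(total);
  std::size_t offset = 0;
  for (const auto& v : per_net) {
    for (std::size_t k = 0; k < v.size(); ++k) flat[offset + k] = v[k];
    offset += v.size();
  }
  assert(offset == total);  // same sanity check as the final KALDI_ASSERT
  return flat;
}
```

A nested network without trainable parameters simply contributes an empty range.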

◆ GetNestedNnet() [1/2]

const Nnet& GetNestedNnet ( int32  id) const
inline

Definition at line 48 of file nnet-parallel-component.h.

References ParallelComponent::nnet_.

{ return nnet_.at(id); }

◆ GetNestedNnet() [2/2]

Nnet& GetNestedNnet ( int32  id)
inline

Definition at line 49 of file nnet-parallel-component.h.

References ParallelComponent::nnet_.

{ return nnet_.at(id); }

◆ GetParams()

void GetParams ( VectorBase< BaseFloat > *  params) const
inline virtual

Get the trainable parameters reshaped as a vector.

Implements UpdatableComponent.

Definition at line 171 of file nnet-parallel-component.h.

References VectorBase< Real >::Dim(), rnnlm::i, KALDI_ASSERT, ParallelComponent::nnet_, ParallelComponent::NumParams(), and VectorBase< Real >::Range().

{
  KALDI_ASSERT(params->Dim() == NumParams());
  int32 offset = 0;
  for (int32 i = 0; i < nnet_.size(); i++) {
    int32 n_params = nnet_[i].NumParams();
    Vector<BaseFloat> params_aux;  // we need 'Vector<>',
    nnet_[i].GetParams(&params_aux);  // copy params from Nnet,
    params->Range(offset, n_params).CopyFromVec(params_aux);
    offset += n_params;
  }
  KALDI_ASSERT(offset == NumParams());
}

◆ GetType()

ComponentType GetType ( ) const
inline virtual

Get type identification of the component.

Implements Component.

Definition at line 46 of file nnet-parallel-component.h.

References Component::kParallelComponent.

{ return kParallelComponent; }

◆ Info()

std::string Info ( ) const
inline virtual

Print some additional info (after <ComponentName> and the dims).

Reimplemented from Component.

Definition at line 195 of file nnet-parallel-component.h.

References rnnlm::i, and ParallelComponent::nnet_.

{
  std::ostringstream os;
  os << "\n";
  for (int32 i = 0; i < nnet_.size(); i++) {
    os << "nested_network #" << i+1 << " {\n"
       << nnet_[i].Info()
       << "}\n";
  }
  std::string s(os.str());
  s.erase(s.end() - 1);  // removing last '\n'
  return s;
}
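The resulting string layout can be reproduced with a small stand-alone function. `FormatNestedInfo` is a hypothetical helper that takes the already formatted nested summaries as plain strings instead of calling `Nnet::Info()`:

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Same layout as ParallelComponent::Info(): a leading newline, one
// 1-based "nested_network #k { ... }" block per nested net, and the
// final '\n' removed.
std::string FormatNestedInfo(const std::vector<std::string>& nested_infos) {
  std::ostringstream os;
  os << "\n";
  for (std::size_t i = 0; i < nested_infos.size(); ++i) {
    os << "nested_network #" << i + 1 << " {\n"
       << nested_infos[i]
       << "}\n";
  }
  std::string s(os.str());
  s.erase(s.end() - 1);  // safe: s always starts with "\n"
  return s;
}
```

Because the stream always begins with "\n", the trailing erase never operates on an empty string, even with zero nested networks.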

◆ InfoBackPropagate()

std::string InfoBackPropagate ( ) const
inline

Definition at line 231 of file nnet-parallel-component.h.

References rnnlm::i, and ParallelComponent::nnet_.

{
  std::ostringstream os;
  for (int32 i = 0; i < nnet_.size(); i++) {
    os << "nested_backpropagate #" << i+1 << " {\n"
       << nnet_[i].InfoBackPropagate(false)
       << "}\n";
  }
  return os.str();
}

◆ InfoGradient()

std::string InfoGradient ( ) const
inline virtual

Print some additional info about the gradient (after <...> and the dims).

Reimplemented from Component.

Definition at line 208 of file nnet-parallel-component.h.

References rnnlm::i, and ParallelComponent::nnet_.

{
  std::ostringstream os;
  os << "\n";
  for (int32 i = 0; i < nnet_.size(); i++) {
    os << "nested_gradient #" << i+1 << " {\n"
       << nnet_[i].InfoGradient(false)
       << "}\n";
  }
  std::string s(os.str());
  s.erase(s.end() - 1);  // removing last '\n'
  return s;
}

◆ InfoPropagate()

std::string InfoPropagate ( ) const
inline

Definition at line 221 of file nnet-parallel-component.h.

References rnnlm::i, and ParallelComponent::nnet_.

{
  std::ostringstream os;
  for (int32 i = 0; i < nnet_.size(); i++) {
    os << "nested_propagate #" << i+1 << " {\n"
       << nnet_[i].InfoPropagate(false)
       << "}\n";
  }
  return os.str();
}

◆ InitData()

void InitData ( std::istream &  is)
inline virtual

Initialize the content of the component from the 'line' of the prototype.

Implements UpdatableComponent.

Definition at line 51 of file nnet-parallel-component.h.

References rnnlm::i, Nnet::Init(), Component::InputDim(), KALDI_ASSERT, KALDI_ERR, KALDI_LOG, ParallelComponent::nnet_, Component::OutputDim(), Nnet::Read(), and kaldi::ReadToken().

{
  // define options
  std::vector<std::string> nested_nnet_proto;
  std::vector<std::string> nested_nnet_filename;
  // parse config
  std::string token;
  while (is >> std::ws, !is.eof()) {
    ReadToken(is, false, &token);
    if (token == "<NestedNnet>" || token == "<NestedNnetFilename>") {
      while (is >> std::ws, !is.eof()) {
        std::string file_or_end;
        ReadToken(is, false, &file_or_end);
        if (file_or_end == "</NestedNnet>" ||
            file_or_end == "</NestedNnetFilename>") break;
        nested_nnet_filename.push_back(file_or_end);
      }
    } else if (token == "<NestedNnetProto>") {
      while (is >> std::ws, !is.eof()) {
        std::string file_or_end;
        ReadToken(is, false, &file_or_end);
        if (file_or_end == "</NestedNnetProto>") break;
        nested_nnet_proto.push_back(file_or_end);
      }
    } else {
      KALDI_ERR << "Unknown token " << token << ", typo in config?"
                << " (NestedNnet|NestedNnetFilename|NestedNnetProto)";
    }
  }
  // Initialize,
  // First, read nnets from files,
  if (nested_nnet_filename.size() > 0) {
    for (int32 i = 0; i < nested_nnet_filename.size(); i++) {
      Nnet nnet;
      nnet.Read(nested_nnet_filename[i]);
      nnet_.push_back(nnet);
      KALDI_LOG << "Loaded nested <Nnet> from file : "
                << nested_nnet_filename[i];
    }
  }
  // Second, initialize nnets from prototypes,
  if (nested_nnet_proto.size() > 0) {
    for (int32 i = 0; i < nested_nnet_proto.size(); i++) {
      Nnet nnet;
      nnet.Init(nested_nnet_proto[i]);
      nnet_.push_back(nnet);
      KALDI_LOG << "Initialized nested <Nnet> from prototype : "
                << nested_nnet_proto[i];
    }
  }
  // Check dim-sum of nested nnets,
  int32 nnet_input_sum = 0, nnet_output_sum = 0;
  for (int32 i = 0; i < nnet_.size(); i++) {
    nnet_input_sum += nnet_[i].InputDim();
    nnet_output_sum += nnet_[i].OutputDim();
  }
  KALDI_ASSERT(InputDim() == nnet_input_sum);
  KALDI_ASSERT(OutputDim() == nnet_output_sum);
}
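A simplified, Kaldi-free sketch of the token loop above. It assumes whitespace-separated tokens (as `ReadToken` reads them in text mode) and that the component dimensions were already consumed before `InitData` sees the stream; `NestedSpec`, `ParseNestedSpec` and `CreationOrder` are hypothetical names. The sketch also makes one behavior of the listing explicit: networks loaded from files are appended to `nnet_` before those initialized from prototypes, whatever the order of the tags, and that order determines which input columns each nested network receives.

```cpp
#include <cassert>
#include <istream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

struct NestedSpec {
  std::vector<std::string> filenames;  // <NestedNnet> / <NestedNnetFilename>
  std::vector<std::string> protos;     // <NestedNnetProto>
};

// Append tokens until one of the closing tags (or end of stream).
static void ReadUntil(std::istream& is, const std::string& end1,
                      const std::string& end2, std::vector<std::string>* out) {
  std::string tok;
  while (is >> tok) {
    if (tok == end1 || tok == end2) return;
    out->push_back(tok);
  }
}

// Simplified version of the parsing loop in InitData().
NestedSpec ParseNestedSpec(const std::string& line) {
  std::istringstream is(line);
  NestedSpec spec;
  std::string token;
  while (is >> token) {
    if (token == "<NestedNnet>" || token == "<NestedNnetFilename>") {
      ReadUntil(is, "</NestedNnet>", "</NestedNnetFilename>", &spec.filenames);
    } else if (token == "<NestedNnetProto>") {
      ReadUntil(is, "</NestedNnetProto>", "</NestedNnetProto>", &spec.protos);
    } else {
      throw std::runtime_error("Unknown token " + token);  // cf. KALDI_ERR
    }
  }
  return spec;
}

// Creation order of the nested nets: all files first, then all prototypes.
std::vector<std::string> CreationOrder(const NestedSpec& spec) {
  std::vector<std::string> order(spec.filenames);
  order.insert(order.end(), spec.protos.begin(), spec.protos.end());
  return order;
}
```

For example, a line listing a prototype before two files still yields the two file-loaded networks first.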

◆ NumParams()

int32 NumParams ( ) const
inline virtual

Number of trainable parameters.

Implements UpdatableComponent.

Definition at line 150 of file nnet-parallel-component.h.

References rnnlm::i, and ParallelComponent::nnet_.

Referenced by ParallelComponent::GetGradient(), ParallelComponent::GetParams(), and ParallelComponent::SetParams().

{
  int32 ans = 0;
  for (int32 i = 0; i < nnet_.size(); i++) {
    ans += nnet_[i].NumParams();
  }
  return ans;
}

◆ PropagateFnc()

void PropagateFnc (const CuMatrixBase< BaseFloat > &in,
                   CuMatrixBase< BaseFloat > *out)
inline virtual

Abstract interface for propagation/backpropagation.

Forward pass transformation (to be implemented by the derived class).

Implements Component.

Definition at line 241 of file nnet-parallel-component.h.

References CuMatrixBase< Real >::ColRange(), CuMatrixBase< Real >::CopyFromMat(), rnnlm::i, and ParallelComponent::nnet_.

{
  // column-offsets for data buffers 'in,out',
  int32 input_offset = 0, output_offset = 0;
  // loop over nnets,
  for (int32 i = 0; i < nnet_.size(); i++) {
    // get the data 'windows',
    CuSubMatrix<BaseFloat> src(
      in.ColRange(input_offset, nnet_[i].InputDim())
    );
    CuSubMatrix<BaseFloat> tgt(
      out->ColRange(output_offset, nnet_[i].OutputDim())
    );
    // forward through auxiliary matrix, as 'Propagate' requires 'CuMatrix',
    CuMatrix<BaseFloat> tgt_aux;
    nnet_[i].Propagate(src, &tgt_aux);
    tgt.CopyFromMat(tgt_aux);
    // advance the offsets,
    input_offset += nnet_[i].InputDim();
    output_offset += nnet_[i].OutputDim();
  }
}
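The forward pass can be sketched for a single frame without Kaldi. Here each nested network is a hypothetical `ToyNet` carrying its declared dimensions and an arbitrary forward function. The input is sliced at the running input offset and each result is written at the running output offset, so the outputs end up concatenated column-wise in the order of `nnet_`.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-in for a nested Nnet: declared dims plus a forward
// function mapping an in_dim vector to an out_dim vector.
struct ToyNet {
  std::size_t in_dim, out_dim;
  std::function<std::vector<double>(const std::vector<double>&)> forward;
};

// One frame through a parallel block: slice, run, concatenate.
std::vector<double> ParallelForward(const std::vector<ToyNet>& nets,
                                    const std::vector<double>& in) {
  std::size_t out_total = 0;
  for (const auto& n : nets) out_total += n.out_dim;
  std::vector<double> out(out_total);
  std::size_t in_off = 0, out_off = 0;
  for (const auto& n : nets) {
    std::vector<double> src(n.in_dim);  // the input 'window'
    for (std::size_t k = 0; k < n.in_dim; ++k) src[k] = in[in_off + k];
    std::vector<double> tgt_aux = n.forward(src);  // auxiliary buffer
    assert(tgt_aux.size() == n.out_dim);
    for (std::size_t k = 0; k < n.out_dim; ++k) out[out_off + k] = tgt_aux[k];
    in_off += n.in_dim;
    out_off += n.out_dim;
  }
  return out;
}
```

With an identity block on columns 0-1 and a summing block on columns 2-4, the input {1, 2, 3, 4, 5} maps to {1, 2, 12}.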

◆ ReadData()

void ReadData (std::istream &is, bool binary)
inline virtual

Reads the component content.

Reimplemented from Component.

Definition at line 109 of file nnet-parallel-component.h.

References kaldi::ExpectToken(), rnnlm::i, Component::InputDim(), KALDI_ASSERT, ParallelComponent::nnet_, Component::OutputDim(), Nnet::Read(), and kaldi::ReadBasicType().

{
  // read
  ExpectToken(is, binary, "<NestedNnetCount>");
  int32 nnet_count;
  ReadBasicType(is, binary, &nnet_count);
  for (int32 i = 0; i < nnet_count; i++) {
    ExpectToken(is, binary, "<NestedNnet>");
    int32 dummy;
    ReadBasicType(is, binary, &dummy);
    Nnet nnet;
    nnet.Read(is, binary);
    nnet_.push_back(nnet);
  }
  ExpectToken(is, binary, "</ParallelComponent>");

  // check dim-sum of nested nnets
  int32 nnet_input_sum = 0, nnet_output_sum = 0;
  for (int32 i = 0; i < nnet_.size(); i++) {
    nnet_input_sum += nnet_[i].InputDim();
    nnet_output_sum += nnet_[i].OutputDim();
  }
  KALDI_ASSERT(InputDim() == nnet_input_sum);
  KALDI_ASSERT(OutputDim() == nnet_output_sum);
}
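A minimal reader for the same text layout, written with plain iostreams. For illustration only, it assumes that each nested network body is a single placeholder token (the real code delegates to `Nnet::Read()`); `Expect` and `ReadNested` are hypothetical helpers in the spirit of `ExpectToken`. As in the listing, the index after `<NestedNnet>` is read and discarded.

```cpp
#include <cassert>
#include <istream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Read one token and fail if it differs from the expected one.
static void Expect(std::istream& is, const std::string& want) {
  std::string tok;
  if (!(is >> tok) || tok != want)
    throw std::runtime_error("Expected " + want + ", got '" + tok + "'");
}

// Parse "<NestedNnetCount> N" followed by N "<NestedNnet> idx body"
// groups and the closing "</ParallelComponent>" tag.
std::vector<std::string> ReadNested(std::istream& is) {
  Expect(is, "<NestedNnetCount>");
  int count = 0;
  if (!(is >> count)) throw std::runtime_error("bad count");
  std::vector<std::string> bodies;
  for (int i = 0; i < count; ++i) {
    Expect(is, "<NestedNnet>");
    int dummy = 0;  // 1-based index written by WriteData(), ignored here
    if (!(is >> dummy)) throw std::runtime_error("bad index");
    std::string body;
    is >> body;  // placeholder for Nnet::Read(is, binary)
    bodies.push_back(body);
  }
  Expect(is, "</ParallelComponent>");
  return bodies;
}
```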

◆ SetBiasLearnRateCoef()

void SetBiasLearnRateCoef ( BaseFloat  val)
inline virtual

Overrides the default, UpdatableComponent::SetBiasLearnRateCoef(...), by applying the value to every updatable component of every nested network.

Reimplemented from UpdatableComponent.

Definition at line 327 of file nnet-parallel-component.h.

References rnnlm::i, rnnlm::j, ParallelComponent::nnet_, and UpdatableComponent::SetBiasLearnRateCoef().

{
  // loop over nnets,
  for (int32 i = 0; i < nnet_.size(); i++) {
    // loop over components,
    for (int32 j = 0; j < nnet_[i].NumComponents(); j++) {
      if (nnet_[i].GetComponent(j).IsUpdatable()) {
        UpdatableComponent& comp =
          dynamic_cast<UpdatableComponent&>(nnet_[i].GetComponent(j));
        // set the value,
        comp.SetBiasLearnRateCoef(val);
      }
    }
  }
}

◆ SetLearnRateCoef()

void SetLearnRateCoef ( BaseFloat  val)
inline virtual

Overrides the default, UpdatableComponent::SetLearnRateCoef(...), by applying the value to every updatable component of every nested network.

Reimplemented from UpdatableComponent.

Definition at line 308 of file nnet-parallel-component.h.

References rnnlm::i, rnnlm::j, ParallelComponent::nnet_, and UpdatableComponent::SetLearnRateCoef().

{
  // loop over nnets,
  for (int32 i = 0; i < nnet_.size(); i++) {
    // loop over components,
    for (int32 j = 0; j < nnet_[i].NumComponents(); j++) {
      if (nnet_[i].GetComponent(j).IsUpdatable()) {
        UpdatableComponent& comp =
          dynamic_cast<UpdatableComponent&>(nnet_[i].GetComponent(j));
        // set the value,
        comp.SetLearnRateCoef(val);
      }
    }
  }
}
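SetLearnRateCoef() and SetBiasLearnRateCoef() share one pattern: visit every component of every nested network and forward the value only to the updatable ones. The sketch below reproduces that pattern on a hypothetical miniature hierarchy (`ToyComponent`, `ToyUpdatable`), not on Kaldi's real classes.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical miniature of the component hierarchy.
struct ToyComponent {
  virtual ~ToyComponent() = default;
  virtual bool IsUpdatable() const { return false; }
};
struct ToyUpdatable : ToyComponent {
  float learn_rate_coef = 1.0f;
  bool IsUpdatable() const override { return true; }
  void SetLearnRateCoef(float v) { learn_rate_coef = v; }
};
using ToyNnet = std::vector<std::unique_ptr<ToyComponent>>;

// Double loop over nested nets and their components; only components
// reporting IsUpdatable() are down-cast and modified.
void SetNestedLearnRateCoef(std::vector<ToyNnet>& nets, float val) {
  for (auto& nnet : nets) {
    for (auto& comp : nnet) {
      if (comp->IsUpdatable()) {
        auto& uc = dynamic_cast<ToyUpdatable&>(*comp);
        uc.SetLearnRateCoef(val);
      }
    }
  }
}
```

Checking IsUpdatable() before the reference `dynamic_cast` matters: a failed cast to a reference throws `std::bad_cast` instead of returning null.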

◆ SetParams()

void SetParams ( const VectorBase< BaseFloat > &  params)
inline virtual

Set the trainable parameters from a vector (same layout as GetParams()).

Implements UpdatableComponent.

Definition at line 184 of file nnet-parallel-component.h.

References VectorBase< Real >::Dim(), rnnlm::i, KALDI_ASSERT, ParallelComponent::nnet_, ParallelComponent::NumParams(), and VectorBase< Real >::Range().

{
  KALDI_ASSERT(params.Dim() == NumParams());
  int32 offset = 0;
  for (int32 i = 0; i < nnet_.size(); i++) {
    int32 n_params = nnet_[i].NumParams();
    nnet_[i].SetParams(params.Range(offset, n_params));
    offset += n_params;
  }
  KALDI_ASSERT(offset == NumParams());
}
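SetParams() performs the inverse split. A minimal sketch that cuts a flat vector into consecutive ranges; the function `UnpackNested` is hypothetical, and its `sizes` argument plays the role of each nested network's NumParams():

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Split a flat parameter vector into consecutive per-net ranges.
std::vector<std::vector<float>> UnpackNested(const std::vector<float>& flat,
                                             const std::vector<std::size_t>& sizes) {
  std::size_t total = 0;
  for (std::size_t n : sizes) total += n;
  assert(flat.size() == total);  // cf. KALDI_ASSERT(params.Dim() == NumParams())
  std::vector<std::vector<float>> per_net;
  std::size_t offset = 0;
  for (std::size_t n : sizes) {
    std::vector<float> part(n);  // like params.Range(offset, n)
    for (std::size_t k = 0; k < n; ++k) part[k] = flat[offset + k];
    per_net.push_back(part);
    offset += n;
  }
  return per_net;
}
```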

◆ SetSeqLengths()

void SetSeqLengths ( const std::vector< int32 > &  sequence_lengths)
inline virtual

Overrides the default, MultistreamComponent::SetSeqLengths(...), by storing the sequence lengths and forwarding them to every nested network.

Reimplemented from MultistreamComponent.

Definition at line 346 of file nnet-parallel-component.h.

References rnnlm::i, ParallelComponent::nnet_, and MultistreamComponent::sequence_lengths_.

{
  sequence_lengths_ = sequence_lengths;
  // loop over nnets,
  for (int32 i = 0; i < nnet_.size(); i++) {
    nnet_[i].SetSeqLengths(sequence_lengths);
  }
}

◆ SetTrainOptions()

void SetTrainOptions (const NnetTrainOptions &opts)
inline virtual

Overrides the default, UpdatableComponent::SetTrainOptions(...), by forwarding the options to every nested network.

Reimplemented from UpdatableComponent.

Definition at line 298 of file nnet-parallel-component.h.

References rnnlm::i, and ParallelComponent::nnet_.

{
  for (int32 i = 0; i < nnet_.size(); i++) {
    nnet_[i].SetTrainOptions(opts);
  }
}

◆ Update()

void Update (const CuMatrixBase< BaseFloat > &input,
             const CuMatrixBase< BaseFloat > &diff)
inline virtual

Compute gradient and update parameters.

Implements UpdatableComponent.

Definition at line 289 of file nnet-parallel-component.h.

{
  { }  // do nothing
}
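Update() is empty, presumably because the nested networks update themselves. This assumes, as in nnet1's `Nnet::Backpropagate()`, that a network calls `Update()` on each of its updatable components while back-propagating; then the nested parameters have already changed inside BackpropagateFnc() by the time the outer network reaches ParallelComponent::Update(). The toy below (hypothetical `ToyNested` and `ToyParallel` types, one scalar weight per nested net) shows that each nested net is updated exactly once, and that doing work in the outer Update() would apply the step a second time.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical nested net with one scalar weight that, like an nnet1 Nnet
// (assumed), applies its own SGD step while back-propagating.
struct ToyNested {
  double w = 1.0;
  double lr = 0.1;
  int updates = 0;
  double Backpropagate(double in, double out_diff) {
    double in_diff = w * out_diff;  // input gradient uses the old weight
    w -= lr * in * out_diff;        // immediate update of w
    ++updates;
    return in_diff;
  }
};

struct ToyParallel {
  std::vector<ToyNested> nets;
  // One column per nested net, for brevity.
  std::vector<double> BackpropagateFnc(const std::vector<double>& in,
                                       const std::vector<double>& out_diff) {
    std::vector<double> in_diff(nets.size());
    for (std::size_t i = 0; i < nets.size(); ++i)
      in_diff[i] = nets[i].Backpropagate(in[i], out_diff[i]);
    return in_diff;
  }
  // Intentionally empty: updating here would step the weights twice.
  void Update(const std::vector<double>&, const std::vector<double>&) {}
};
```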

◆ WriteData()

void WriteData (std::ostream &os, bool binary) const
inline virtual

Writes the component content.

Reimplemented from Component.

Definition at line 134 of file nnet-parallel-component.h.

References rnnlm::i, ParallelComponent::nnet_, kaldi::WriteBasicType(), and kaldi::WriteToken().

{
  // useful dims
  int32 nnet_count = nnet_.size();
  //
  WriteToken(os, binary, "<NestedNnetCount>");
  WriteBasicType(os, binary, nnet_count);
  if (!binary) os << "\n";
  for (int32 i = 0; i < nnet_count; i++) {
    WriteToken(os, binary, "<NestedNnet>");
    WriteBasicType(os, binary, i+1);
    if (!binary) os << "\n";
    nnet_[i].Write(os, binary);
  }
  WriteToken(os, binary, "</ParallelComponent>");
}
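The text-mode output can be sketched with plain iostreams. It assumes that in text mode each token and each integer is followed by a single space, and, per the listing, that a newline follows each count or index. The nested bodies are hypothetical placeholder strings standing in for `Nnet::Write()` output, and `WriteNested` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

// Text-mode layout of WriteData(): the count, then each nested body
// preceded by <NestedNnet> and its 1-based index, then the closing tag.
void WriteNested(std::ostream& os, const std::vector<std::string>& bodies) {
  os << "<NestedNnetCount> " << bodies.size() << " \n";
  for (std::size_t i = 0; i < bodies.size(); ++i) {
    os << "<NestedNnet> " << i + 1 << " \n";
    os << bodies[i];  // placeholder for nnet_[i].Write(os, binary)
  }
  os << "</ParallelComponent> ";
}
```

The 1-based index is purely informational; ReadData() reads it into a dummy variable and discards it.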

Member Data Documentation

◆ nnet_

std::vector< Nnet > nnet_
private

The documentation for this class was generated from the following file:

nnet-parallel-component.h