nnet-nnet.cc
Go to the documentation of this file.
1 // nnet/nnet-nnet.cc
2 
3 // Copyright 2011-2016 Brno University of Technology (Author: Karel Vesely)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include "nnet/nnet-nnet.h"
21 #include "nnet/nnet-component.h"
24 #include "nnet/nnet-activation.h"
26 #include "nnet/nnet-various.h"
27 
28 namespace kaldi {
29 namespace nnet1 {
30 
32 }
33 
35  Destroy();
36 }
37 
38 Nnet::Nnet(const Nnet& other) {
39  // copy the components
40  for (int32 i = 0; i < other.NumComponents(); i++) {
41  components_.push_back(other.GetComponent(i).Copy());
42  }
43  // create empty buffers
44  propagate_buf_.resize(NumComponents()+1);
46  // copy train opts
47  SetTrainOptions(other.opts_);
48  Check();
49 }
50 
51 Nnet& Nnet::operator= (const Nnet& other) {
52  Destroy();
53  // copy the components
54  for (int32 i = 0; i < other.NumComponents(); i++) {
55  components_.push_back(other.GetComponent(i).Copy());
56  }
57  // create empty buffers
58  propagate_buf_.resize(NumComponents()+1);
60  // copy train opts
61  SetTrainOptions(other.opts_);
62  Check();
63  return *this;
64 }
65 
71  CuMatrix<BaseFloat> *out) {
72  // In case of empty network copy input to output,
73  if (NumComponents() == 0) {
74  (*out) = in; // copy,
75  return;
76  }
77  // We need C+1 buffers,
78  if (propagate_buf_.size() != NumComponents()+1) {
79  propagate_buf_.resize(NumComponents()+1);
80  }
81  // Copy input to first buffer,
82  propagate_buf_[0] = in;
83  // Propagate through all the components,
84  for (int32 i = 0; i < static_cast<int32>(components_.size()); i++) {
85  components_[i]->Propagate(propagate_buf_[i], &propagate_buf_[i+1]);
86  }
87  // Copy the output from the last buffer,
88  (*out) = propagate_buf_[NumComponents()];
89 }
90 
91 
97  CuMatrix<BaseFloat> *in_diff) {
98  // Copy the derivative in case of empty network,
99  if (NumComponents() == 0) {
100  (*in_diff) = out_diff; // copy,
101  return;
102  }
103  // We need C+1 buffers,
104  KALDI_ASSERT(static_cast<int32>(propagate_buf_.size()) == NumComponents()+1);
105  if (backpropagate_buf_.size() != NumComponents()+1) {
106  backpropagate_buf_.resize(NumComponents()+1);
107  }
108  // Copy 'out_diff' to last buffer,
109  backpropagate_buf_[NumComponents()] = out_diff;
110  // Loop from last Component to the first,
111  for (int32 i = NumComponents()-1; i >= 0; i--) {
112  // Backpropagate through 'Component',
113  components_[i]->Backpropagate(propagate_buf_[i],
114  propagate_buf_[i+1],
115  backpropagate_buf_[i+1],
116  &backpropagate_buf_[i]);
117  // Update 'Component' (if applicable),
118  if (components_[i]->IsUpdatable()) {
119  UpdatableComponent* uc =
120  dynamic_cast<UpdatableComponent*>(components_[i]);
122  }
123  }
124  // Export the derivative (if applicable),
125  if (NULL != in_diff) {
126  (*in_diff) = backpropagate_buf_[0];
127  }
128 }
129 
130 
132  CuMatrix<BaseFloat> *out) {
133  KALDI_ASSERT(NULL != out);
134  (*out) = in; // works even with 0 components,
135  CuMatrix<BaseFloat> tmp_in;
136  for (int32 i = 0; i < NumComponents(); i++) {
137  out->Swap(&tmp_in);
138  components_[i]->Propagate(tmp_in, out);
139  }
140 }
141 
142 
144  KALDI_ASSERT(!components_.empty());
145  return components_.back()->OutputDim();
146 }
147 
149  KALDI_ASSERT(!components_.empty());
150  return components_.front()->InputDim();
151 }
152 
154  return *(components_.at(c));
155 }
156 
158  return *(components_.at(c));
159 }
160 
162  return *(components_.at(NumComponents()-1));
163 }
164 
166  return *(components_.at(NumComponents()-1));
167 }
168 
169 void Nnet::ReplaceComponent(int32 c, const Component& comp) {
170  delete components_.at(c);
171  components_.at(c) = comp.Copy(); // deep copy,
172  Check();
173 }
174 
176  Component* tmp = components_.at(c);
177  components_.at(c) = *comp;
178  (*comp) = tmp;
179  Check();
180 }
181 
182 void Nnet::AppendComponent(const Component& comp) {
183  components_.push_back(comp.Copy()); // append,
184  Check();
185 }
186 
187 void Nnet::AppendComponentPointer(Component* dynamically_allocated_comp) {
188  components_.push_back(dynamically_allocated_comp); // append,
189  Check();
190 }
191 
192 void Nnet::AppendNnet(const Nnet& other) {
193  for (int32 i = 0; i < other.NumComponents(); i++) {
195  }
196  Check();
197 }
198 
200  Component* ptr = components_.at(c);
201  components_.erase(components_.begin()+c);
202  delete ptr;
203  Check();
204 }
205 
208 }
209 
211  int32 n_params = 0;
212  for (int32 n = 0; n < components_.size(); n++) {
213  if (components_[n]->IsUpdatable()) {
214  n_params +=
215  dynamic_cast<UpdatableComponent*>(components_[n])->NumParams();
216  }
217  }
218  return n_params;
219 }
220 
221 void Nnet::GetGradient(Vector<BaseFloat>* gradient) const {
222  gradient->Resize(NumParams());
223  int32 pos = 0;
224  // loop over Components,
225  for (int32 i = 0; i < components_.size(); i++) {
226  if (components_[i]->IsUpdatable()) {
227  UpdatableComponent& c =
228  dynamic_cast<UpdatableComponent&>(*components_[i]);
229  SubVector<BaseFloat> grad_range(gradient->Range(pos, c.NumParams()));
230  c.GetGradient(&grad_range); // getting gradient,
231  pos += c.NumParams();
232  }
233  }
234  KALDI_ASSERT(pos == NumParams());
235 }
236 
237 void Nnet::GetParams(Vector<BaseFloat>* params) const {
238  params->Resize(NumParams());
239  int32 pos = 0;
240  // loop over Components,
241  for (int32 i = 0; i < components_.size(); i++) {
242  if (components_[i]->IsUpdatable()) {
243  UpdatableComponent& c =
244  dynamic_cast<UpdatableComponent&>(*components_[i]);
245  SubVector<BaseFloat> params_range(params->Range(pos, c.NumParams()));
246  c.GetParams(&params_range); // getting params,
247  pos += c.NumParams();
248  }
249  }
250  KALDI_ASSERT(pos == NumParams());
251 }
252 
254  KALDI_ASSERT(params.Dim() == NumParams());
255  int32 pos = 0;
256  // loop over Components,
257  for (int32 i = 0; i < components_.size(); i++) {
258  if (components_[i]->IsUpdatable()) {
259  UpdatableComponent& c =
260  dynamic_cast<UpdatableComponent&>(*components_[i]);
261  c.SetParams(params.Range(pos, c.NumParams())); // setting params,
262  pos += c.NumParams();
263  }
264  }
265  KALDI_ASSERT(pos == NumParams());
266 }
267 
269  for (int32 c = 0; c < NumComponents(); c++) {
271  Dropout& comp = dynamic_cast<Dropout&>(GetComponent(c));
272  BaseFloat r_old = comp.GetDropoutRate();
273  comp.SetDropoutRate(r);
274  KALDI_LOG << "Setting dropout-rate in component " << c
275  << " from " << r_old << " to " << r;
276  }
277  }
278 }
279 
280 
281 void Nnet::ResetStreams(const std::vector<int32> &stream_reset_flag) {
282  for (int32 c = 0; c < NumComponents(); c++) {
283  if (GetComponent(c).IsMultistream()) {
284  MultistreamComponent& comp =
285  dynamic_cast<MultistreamComponent&>(GetComponent(c));
286  comp.ResetStreams(stream_reset_flag);
287  }
288  }
289 }
290 
291 void Nnet::SetSeqLengths(const std::vector<int32> &sequence_lengths) {
292  for (int32 c = 0; c < NumComponents(); c++) {
293  if (GetComponent(c).IsMultistream()) {
294  MultistreamComponent& comp =
295  dynamic_cast<MultistreamComponent&>(GetComponent(c));
296  comp.SetSeqLengths(sequence_lengths);
297  }
298  }
299 }
300 
301 void Nnet::Init(const std::string &proto_file) {
302  Input in(proto_file);
303  std::istream &is = in.Stream();
304  std::string proto_line, token;
305 
306  // Initialize from the prototype, where each line
307  // contains the description for one component.
308  while (is >> std::ws, !is.eof()) {
309  KALDI_ASSERT(is.good());
310 
311  // get a line from the proto file,
312  std::getline(is, proto_line);
313  if (proto_line == "") continue;
314  KALDI_VLOG(1) << proto_line;
315 
316  // get the 1st token from the line,
317  std::istringstream(proto_line) >> std::ws >> token;
318  // ignore these tokens:
319  if (token == "<NnetProto>" || token == "</NnetProto>") continue;
320 
321  // create new component, append to Nnet,
322  this->AppendComponentPointer(Component::Init(proto_line+"\n"));
323  }
324  // cleanup
325  in.Close();
326  Check();
327 }
328 
329 
333 void Nnet::Read(const std::string &rxfilename) {
334  bool binary;
335  Input in(rxfilename, &binary);
336  Read(in.Stream(), binary);
337  in.Close();
338  // Warn if the NN is empty
339  if (NumComponents() == 0) {
340  KALDI_WARN << "The network '" << rxfilename << "' is empty.";
341  }
342 }
343 
344 
345 void Nnet::Read(std::istream &is, bool binary) {
346  // Read the Components through the 'factory' Component::Read(...),
347  Component* comp(NULL);
348  while (comp = Component::Read(is, binary), comp != NULL) {
349  // Check dims,
350  if (NumComponents() > 0) {
351  if (components_.back()->OutputDim() != comp->InputDim()) {
352  KALDI_ERR << "Dimensionality mismatch!"
353  << " Previous layer output:" << components_.back()->OutputDim()
354  << " Current layer input:" << comp->InputDim();
355  }
356  }
357  // Append to 'this' Nnet,
359  }
360  Check();
361 }
362 
363 
367 void Nnet::Write(const std::string &wxfilename, bool binary) const {
368  Output out(wxfilename, binary, true);
369  Write(out.Stream(), binary);
370  out.Close();
371 }
372 
373 
374 void Nnet::Write(std::ostream &os, bool binary) const {
375  Check();
376  WriteToken(os, binary, "<Nnet>");
377  if (binary == false) os << std::endl;
378  for (int32 i = 0; i < NumComponents(); i++) {
379  components_[i]->Write(os, binary);
380  }
381  WriteToken(os, binary, "</Nnet>");
382  if (binary == false) os << std::endl;
383 }
384 
385 
386 std::string Nnet::Info() const {
387  // global info
388  std::ostringstream ostr;
389  ostr << "num-components " << NumComponents() << std::endl;
390  if (NumComponents() == 0)
391  return ostr.str();
392  ostr << "input-dim " << InputDim() << std::endl;
393  ostr << "output-dim " << OutputDim() << std::endl;
394  ostr << "number-of-parameters " << static_cast<float>(NumParams())/1e6
395  << " millions" << std::endl;
396  // topology & weight stats
397  for (int32 i = 0; i < NumComponents(); i++) {
398  ostr << "component " << i+1 << " : "
399  << Component::TypeToMarker(components_[i]->GetType())
400  << ", input-dim " << components_[i]->InputDim()
401  << ", output-dim " << components_[i]->OutputDim()
402  << ", " << components_[i]->Info() << std::endl;
403  }
404  return ostr.str();
405 }
406 
407 std::string Nnet::InfoGradient(bool header) const {
408  std::ostringstream ostr;
409  // gradient stats
410  if (header) ostr << "\n### GRADIENT STATS :\n";
411  for (int32 i = 0; i < NumComponents(); i++) {
412  ostr << "Component " << i+1 << " : "
413  << Component::TypeToMarker(components_[i]->GetType())
414  << ", " << components_[i]->InfoGradient() << std::endl;
415  }
416  if (header) ostr << "### END GRADIENT\n";
417  return ostr.str();
418 }
419 
420 std::string Nnet::InfoPropagate(bool header) const {
421  std::ostringstream ostr;
422  // forward-pass buffer stats
423  if (header) ostr << "\n### FORWARD PROPAGATION BUFFER CONTENT :\n";
424  ostr << "[0] output of <Input> " << MomentStatistics(propagate_buf_[0])
425  << std::endl;
426  for (int32 i = 0; i < NumComponents(); i++) {
427  ostr << "[" << 1+i << "] output of "
428  << Component::TypeToMarker(components_[i]->GetType())
429  << MomentStatistics(propagate_buf_[i+1]) << std::endl;
430  // nested networks too...
431  if (Component::kParallelComponent == components_[i]->GetType()) {
432  ostr <<
433  dynamic_cast<ParallelComponent*>(components_[i])->InfoPropagate();
434  }
435  if (Component::kMultiBasisComponent == components_[i]->GetType()) {
436  ostr << dynamic_cast<MultiBasisComponent*>(components_[i])->InfoPropagate();
437  }
438  }
439  if (header) ostr << "### END FORWARD\n";
440  return ostr.str();
441 }
442 
443 std::string Nnet::InfoBackPropagate(bool header) const {
444  std::ostringstream ostr;
445  // forward-pass buffer stats
446  if (header) ostr << "\n### BACKWARD PROPAGATION BUFFER CONTENT :\n";
447  ostr << "[0] diff of <Input> " << MomentStatistics(backpropagate_buf_[0])
448  << std::endl;
449  for (int32 i = 0; i < NumComponents(); i++) {
450  ostr << "["<<1+i<< "] diff-output of "
451  << Component::TypeToMarker(components_[i]->GetType())
452  << MomentStatistics(backpropagate_buf_[i+1]) << std::endl;
453  // nested networks too...
454  if (Component::kParallelComponent == components_[i]->GetType()) {
455  ostr <<
456  dynamic_cast<ParallelComponent*>(components_[i])->InfoBackPropagate();
457  }
458  if (Component::kMultiBasisComponent == components_[i]->GetType()) {
459  ostr << dynamic_cast<MultiBasisComponent*>(components_[i])->InfoBackPropagate();
460  }
461  }
462  if (header) ostr << "### END BACKWARD\n\n";
463  return ostr.str();
464 }
465 
466 
467 void Nnet::Check() const {
468  // check dims,
469  for (size_t i = 0; i + 1 < components_.size(); i++) {
470  KALDI_ASSERT(components_[i] != NULL);
471  int32 output_dim = components_[i]->OutputDim(),
472  next_input_dim = components_[i+1]->InputDim();
473  // show error message,
474  if (output_dim != next_input_dim) {
475  KALDI_ERR << "Component dimension mismatch!"
476  << " Output dim of [" << i << "] "
477  << Component::TypeToMarker(components_[i]->GetType())
478  << " is " << output_dim << ". "
479  << "Input dim of next [" << i+1 << "] "
480  << Component::TypeToMarker(components_[i+1]->GetType())
481  << " is " << next_input_dim << ".";
482  }
483  }
484  // check for nan/inf in network weights,
485  Vector<BaseFloat> weights;
486  GetParams(&weights);
487  BaseFloat sum = weights.Sum();
488  if (KALDI_ISINF(sum)) {
489  KALDI_ERR << "'inf' in network parameters "
490  << "(weight explosion, need lower learning rate?)";
491  }
492  if (KALDI_ISNAN(sum)) {
493  KALDI_ERR << "'nan' in network parameters (need lower learning rate?)";
494  }
495 }
496 
497 
499  for (int32 i = 0; i < NumComponents(); i++) {
500  delete components_[i];
501  }
502  components_.resize(0);
503  propagate_buf_.resize(0);
504  backpropagate_buf_.resize(0);
505 }
506 
507 
509  opts_ = opts;
510  // set values to individual components,
511  for (int32 l = 0; l < NumComponents(); l++) {
512  if (GetComponent(l).IsUpdatable()) {
514  }
515  }
516 }
517 
518 
519 } // namespace nnet1
520 } // namespace kaldi
void Backpropagate(const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
Perform backward pass through the network,.
Definition: nnet-nnet.cc:96
void RemoveLastComponent()
Remove the last of the Components,.
Definition: nnet-nnet.cc:206
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void ResetStreams(const std::vector< int32 > &stream_reset_flag)
Reset streams in multi-stream training,.
Definition: nnet-nnet.cc:281
void SetParams(const VectorBase< BaseFloat > &params)
Set the network weights from a supervector,.
Definition: nnet-nnet.cc:253
void AppendComponentPointer(Component *dynamically_allocated_comp)
Append Component* to 'this' instance of Nnet by a shallow copy ('this' instance of Nnet takes over ownership of the pointer).
Definition: nnet-nnet.cc:187
void ReplaceComponent(int32 c, const Component &comp)
Replace the c'th component in 'this' Nnet (deep copy).
Definition: nnet-nnet.cc:169
virtual void GetGradient(VectorBase< BaseFloat > *gradient) const =0
Get gradient reshaped as a vector,.
virtual int32 NumParams() const =0
Number of trainable parameters,.
void Propagate(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network,.
Definition: nnet-nnet.cc:70
void GetParams(Vector< BaseFloat > *params) const
Get the network weights in a supervector,.
Definition: nnet-nnet.cc:237
int32 NumParams() const
Get the number of parameters in the network,.
Definition: nnet-nnet.cc:210
std::string MomentStatistics(const VectorBase< Real > &vec)
Get a string with statistics of the data in a vector, so we can print them easily.
Definition: nnet-utils.h:63
void SetSeqLengths(const std::vector< int32 > &sequence_lengths)
Set sequence length in LSTM multi-stream training,.
Definition: nnet-nnet.cc:291
int32 NumComponents() const
Returns the number of 'Components' which form the NN.
Definition: nnet-nnet.h:66
virtual void SetParams(const VectorBase< BaseFloat > &params)=0
Set the trainable parameters from a vector (the parameters reshaped as a vector).
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename'.
Definition: nnet-nnet.cc:367
#define KALDI_ISINF
Definition: kaldi-math.h:73
int32 InputDim() const
Dimensionality on network input (input feature dim.),.
Definition: nnet-nnet.cc:148
Class UpdatableComponent is a Component which has trainable parameters, it contains SGD training hype...
kaldi::int32 int32
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
Definition: matrix-common.h:71
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
static Component * Init(const std::string &conf_line)
Initialize component from a line in config file,.
static Component * Read(std::istream &is, bool binary)
Read the component from a stream (static method),.
const Component & GetLastComponent() const
LastComponent accessor,.
Definition: nnet-nnet.cc:161
void SetDropoutRate(BaseFloat dr)
virtual bool IsUpdatable() const
Check if component has 'Updatable' interface (trainable components).
virtual void SetSeqLengths(const std::vector< int32 > &sequence_lengths)
static const char * TypeToMarker(ComponentType t)
Converts component type to marker,.
virtual void Update(const CuMatrixBase< BaseFloat > &input, const CuMatrixBase< BaseFloat > &diff)=0
Compute gradient and update parameters,.
void GetGradient(Vector< BaseFloat > *gradient) const
Get the gradient stored in the network,.
Definition: nnet-nnet.cc:221
std::istream & Stream()
Definition: kaldi-io.cc:826
virtual void ResetStreams(const std::vector< int32 > &stream_reset_flag)
Optional function to reset the transfer of context (not used for BLSTMs).
std::ostream & Stream()
Definition: kaldi-io.cc:701
int32 OutputDim() const
Dimensionality of network outputs (posteriors | bn-features | etc.),.
Definition: nnet-nnet.cc:143
void SwapComponent(int32 c, Component **comp)
Swap the c'th component with the pointer.
Definition: nnet-nnet.cc:175
void Swap(Matrix< Real > *mat)
Definition: cu-matrix.cc:123
std::string InfoBackPropagate(bool header=true) const
Create string with back-propagation-buffer statistics,.
Definition: nnet-nnet.cc:443
void AppendNnet(const Nnet &nnet_to_append)
Append other Nnet to 'this' Nnet (copy all its components).
Definition: nnet-nnet.cc:192
struct rnnlm::@11::@12 n
int32 InputDim() const
Get the dimension of the input,.
virtual Component * Copy() const =0
Copy component (deep copy),.
#define KALDI_ERR
Definition: kaldi-error.h:147
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename'.
Definition: nnet-nnet.cc:333
virtual bool IsMultistream() const
Check if component has &#39;Recurrent&#39; interface (trainable and recurrent),.
#define KALDI_WARN
Definition: kaldi-error.h:150
void Check() const
Consistency check,.
Definition: nnet-nnet.cc:467
void Destroy()
Release the memory.
Definition: nnet-nnet.cc:498
int32 Close()
Definition: kaldi-io.cc:761
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
Class MultistreamComponent is an extension of UpdatableComponent for recurrent networks, which are trained with parallel sequences.
std::vector< CuMatrix< BaseFloat > > propagate_buf_
Buffers for forward pass (on demand initialization),.
Definition: nnet-nnet.h:173
Real Sum() const
Returns sum of the elements.
void Feedforward(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network (with 2 swapping buffers),.
Definition: nnet-nnet.cc:131
std::vector< Component * > components_
Vector which contains all the components composing the neural network, the components are for example...
Definition: nnet-nnet.h:170
Matrix for CUDA computing.
Definition: matrix-common.h:69
void Init(const std::string &proto_file)
Initialize the Nnet from the prototype,.
Definition: nnet-nnet.cc:301
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ISNAN
Definition: kaldi-math.h:72
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
virtual void GetParams(VectorBase< BaseFloat > *params) const =0
Get the trainable parameters reshaped as a vector,.
void SetDropoutRate(BaseFloat r)
Set the dropout rate.
Definition: nnet-nnet.cc:268
std::string InfoGradient(bool header=true) const
Create string with per-component gradient statistics,.
Definition: nnet-nnet.cc:407
std::string InfoPropagate(bool header=true) const
Create string with propagation-buffer statistics,.
Definition: nnet-nnet.cc:420
std::string Info() const
Create string with human readable description of the nnet,.
Definition: nnet-nnet.cc:386
void RemoveComponent(int32 c)
Remove the c'th component.
Definition: nnet-nnet.cc:199
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
const Component & GetComponent(int32 c) const
Component accessor,.
Definition: nnet-nnet.cc:153
virtual ComponentType GetType() const =0
Get Type Identification of the component,.
void SetTrainOptions(const NnetTrainOptions &opts)
Set hyper-parameters of the training (pushes to all UpdatableComponents),.
Definition: nnet-nnet.cc:508
Abstract class, building block of the network.
NnetTrainOptions opts_
Option class with hyper-parameters passed to UpdatableComponent(s)
Definition: nnet-nnet.h:178
std::vector< CuMatrix< BaseFloat > > backpropagate_buf_
Buffers for backward pass (on demand initialization),.
Definition: nnet-nnet.h:175
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
#define KALDI_LOG
Definition: kaldi-error.h:153
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:501
bool Close()
Definition: kaldi-io.cc:677
void AppendComponent(const Component &comp)
Append Component to &#39;this&#39; instance of Nnet (deep copy),.
Definition: nnet-nnet.cc:182
Nnet & operator=(const Nnet &other)
Definition: nnet-nnet.cc:51
SubVector< Real > Range(const MatrixIndexT o, const MatrixIndexT l)
Returns a sub-vector of a vector (a range of elements).
Definition: kaldi-vector.h:94