All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
FisherComputationClass Class Reference
Inheritance diagram for FisherComputationClass:
Collaboration diagram for FisherComputationClass:

Public Member Functions

 FisherComputationClass (const Nnet &nnet, const std::vector< Nnet > &nnets, const std::vector< NnetExample > &egs, int32 minibatch_size, SpMatrix< double > *scatter)
 
 FisherComputationClass (const FisherComputationClass &other)
 
void operator() ()
 
 ~FisherComputationClass ()
 
- Public Member Functions inherited from MultiThreadable
virtual ~MultiThreadable ()
 

Private Attributes

const Nnet & nnet_
 
const std::vector< Nnet > & nnets_
 
const std::vector< NnetExample > & egs_
 
int32 minibatch_size_
 
SpMatrix< double > * scatter_ptr_
 
SpMatrix< double > scatter_
 

Additional Inherited Members

- Public Attributes inherited from MultiThreadable
int32 thread_id_
 
int32 num_threads_
 

Detailed Description

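Worker functor, derived from MultiThreadable, that accumulates a scatter matrix of per-minibatch gradient projections.

Each call to operator()() handles only the minibatches b for which b % num_threads_ == thread_id_. A minibatch is up to minibatch_size_ consecutive examples of egs_. For each such minibatch it:
- zeroes a gradient copy of nnet_;
- back-propagates the minibatch through nnet_ with DoBackprop(), accumulating into that copy;
- builds a vector $g$ of dimension nnets_.size() * nnet_.NumUpdatableComponents(). Each entry is the dot product between one updatable component's gradient and the corresponding component of one network in nnets_.

The rank-one term $g g^{\top}$ is then added to the object's own scatter_ with AddVec2().

When the object is destroyed and its scatter_ is non-empty, scatter_ is added into *scatter_ptr_, which is first resized if it is still empty.

Definition at line 31 of file combine-nnet-fast.cc.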

Constructor & Destructor Documentation

FisherComputationClass ( const Nnet &  nnet,
const std::vector< Nnet > &  nnets,
const std::vector< NnetExample > &  egs,
int32  minibatch_size,
SpMatrix< double > *  scatter 
)
inline

Definition at line 33 of file combine-nnet-fast.cc.

37  :
38  nnet_(nnet), nnets_(nnets), egs_(egs), minibatch_size_(minibatch_size),
39  scatter_ptr_(scatter) { } // This initializer is only used to create a
const std::vector< Nnet > & nnets_
const std::vector< NnetExample > & egs_

FisherComputationClass ( const FisherComputationClass &  other )
inline

Definition at line 43 of file combine-nnet-fast.cc.

References FisherComputationClass::nnet_, FisherComputationClass::nnets_, Nnet::NumUpdatableComponents(), SpMatrix< Real >::Resize(), and FisherComputationClass::scatter_.

43  :
44  MultiThreadable(other),
45  nnet_(other.nnet_), nnets_(other.nnets_), egs_(other.egs_),
46  minibatch_size_(other.minibatch_size_), scatter_ptr_(other.scatter_ptr_) {
47  scatter_.Resize(nnets_.size() * nnet_.NumUpdatableComponents());
48  }
const std::vector< Nnet > & nnets_
void Resize(MatrixIndexT nRows, MatrixResizeType resize_type=kSetZero)
Definition: sp-matrix.h:81
int32 NumUpdatableComponents() const
Returns the number of updatable components.
Definition: nnet-nnet.cc:413
const std::vector< NnetExample > & egs_

~FisherComputationClass ( )
inline

Definition at line 82 of file combine-nnet-fast.cc.

References SpMatrix< Real >::AddSp(), PackedMatrix< Real >::NumRows(), SpMatrix< Real >::Resize(), FisherComputationClass::scatter_, and FisherComputationClass::scatter_ptr_.

82  {
83  if (scatter_.NumRows() != 0) {
84  if (scatter_ptr_->NumRows() == 0)
85  scatter_ptr_->Resize(scatter_.NumRows());
86  scatter_ptr_->AddSp(1.0, scatter_);
87  }
88  }
void AddSp(const Real alpha, const SpMatrix< Real > &Ma)
Definition: sp-matrix.h:211
MatrixIndexT NumRows() const
void Resize(MatrixIndexT nRows, MatrixResizeType resize_type=kSetZero)
Definition: sp-matrix.h:81

Member Function Documentation

void operator() ( )
inline virtual

Implements MultiThreadable.

Definition at line 49 of file combine-nnet-fast.cc.

References SpMatrix< Real >::AddVec2(), VectorBase< Real >::Dim(), kaldi::nnet2::DoBackprop(), UpdatableComponent::DotProduct(), FisherComputationClass::egs_, Nnet::GetComponent(), rnnlm::i, KALDI_ASSERT, FisherComputationClass::minibatch_size_, rnnlm::n, FisherComputationClass::nnet_, FisherComputationClass::nnets_, MultiThreadable::num_threads_, Nnet::NumComponents(), Nnet::NumUpdatableComponents(), FisherComputationClass::scatter_, Nnet::SetZero(), and MultiThreadable::thread_id_.

49  {
50  // b is the "minibatch id."
51  int32 num_egs = static_cast<int32>(egs_.size());
52  Nnet nnet_gradient(nnet_);
53  for (int32 b = 0; b * minibatch_size_ < num_egs; b++) {
54  if (b % num_threads_ != thread_id_)
55  continue; // We're not responsible for this minibatch.
56  int32 offset = b * minibatch_size_,
57  length = std::min(minibatch_size_,
58  num_egs - offset);
59  bool is_gradient = true;
60  nnet_gradient.SetZero(is_gradient);
61  std::vector<NnetExample> minibatch(egs_.begin() + offset,
62  egs_.begin() + offset + length);
63  DoBackprop(nnet_, minibatch, &nnet_gradient);
64  Vector<double> gradient(nnets_.size() * nnet_.NumUpdatableComponents());
65  int32 i = 0;
66  for (int32 n = 0; n < static_cast<int32>(nnets_.size()); n++) {
67  for (int32 c = 0; c < nnet_.NumComponents(); c++) {
68  const UpdatableComponent *uc = dynamic_cast<const UpdatableComponent*>(
69  &(nnet_gradient.GetComponent(c))),
70  *uc_other = dynamic_cast<const UpdatableComponent*>(
71  &(nnets_[n].GetComponent(c)));
72  if (uc != NULL) {
73  gradient(i) = uc->DotProduct(*uc_other);
74  i++;
75  }
76  }
77  }
78  KALDI_ASSERT(i == gradient.Dim());
79  scatter_.AddVec2(1.0, gradient);
80  }
81  }
const std::vector< Nnet > & nnets_
double DoBackprop(const Nnet &nnet, const std::vector< NnetExample > &examples, Nnet *nnet_to_update, double *tot_accuracy)
This function computes the objective function and either updates the model or adds to parameter gradi...
Definition: nnet-update.cc:265
struct rnnlm::@11::@12 n
int32 NumComponents() const
Returns number of components -- think of this as similar to # of layers, but e.g. ...
Definition: nnet-nnet.h:69
void AddVec2(const Real alpha, const VectorBase< OtherReal > &v)
rank-one update, this <-- this + alpha v v'
Definition: sp-matrix.cc:946
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:169
int32 NumUpdatableComponents() const
Returns the number of updatable components.
Definition: nnet-nnet.cc:413
const std::vector< NnetExample > & egs_

Member Data Documentation

const std::vector<NnetExample>& egs_
private

Definition at line 94 of file combine-nnet-fast.cc.

Referenced by FisherComputationClass::operator()().

int32 minibatch_size_
private

Definition at line 95 of file combine-nnet-fast.cc.

Referenced by FisherComputationClass::operator()().

const Nnet& nnet_
private
const std::vector<Nnet>& nnets_
private
SpMatrix<double>* scatter_ptr_
private

The documentation for this class was generated from the following file:
- combine-nnet-fast.cc