#include <nnet-update.h>
Public Member Functions
NnetUpdater (const Nnet &nnet, Nnet *nnet_to_update)
double ComputeForMinibatch (const std::vector< NnetExample > &data, double *tot_accuracy)
Does the entire forward and backward computation for this minibatch.
double ComputeForMinibatch (const std::vector< NnetExample > &data, Matrix< BaseFloat > *formatted_data, double *tot_accuracy)
This version of ComputeForMinibatch is used when you have already called kaldi::nnet2::FormatNnetInput() to format your data as a single matrix.
void GetOutput (CuMatrix< BaseFloat > *output)
Protected Member Functions
void Propagate ()
void FormatInput (const std::vector< NnetExample > &data)
Formats the input as a single matrix, sets the size of forward_data_, and sets up chunk_info_out_.
double ComputeObjfAndDeriv (const std::vector< NnetExample > &data, CuMatrix< BaseFloat > *deriv, double *tot_accuracy=NULL) const
Computes the objective function and its derivative at the output layer, but does not do the backprop [for that, see Backprop()].
void Backprop (CuMatrix< BaseFloat > *deriv) const
Backprop must be called after ComputeObjfAndDeriv.
Private Member Functions
double ComputeTotAccuracy (const std::vector< NnetExample > &data) const
Private Attributes
const Nnet & nnet_
Nnet * nnet_to_update_
int32 num_chunks_
std::vector< ChunkInfo > chunk_info_out_
std::vector< CuMatrix< BaseFloat > > forward_data_
Friends
class NnetEnsembleTrainer
Definition at line 46 of file nnet-update.h.
NnetUpdater (const Nnet &nnet, Nnet *nnet_to_update)
Definition at line 28 of file nnet-update.cc.
void Backprop (CuMatrix< BaseFloat > *deriv) const [protected]
Backprop must be called after ComputeObjfAndDeriv.
Does the backpropagation; "nnet_to_update_" is updated. Note: on input, "deriv" contains the derivative with respect to the output layer (as computed by ComputeObjfAndDeriv), but this function also uses it as a temporary variable, so its contents are not preserved.
Definition at line 188 of file nnet-update.cc.
References Component::Backprop(), NnetUpdater::chunk_info_out_, Nnet::FirstUpdatableComponent(), NnetUpdater::forward_data_, Nnet::GetComponent(), NnetUpdater::nnet_, NnetUpdater::nnet_to_update_, CuMatrixBase< Real >::NumCols(), Nnet::NumComponents(), and CuMatrixBase< Real >::NumRows().
Referenced by NnetUpdater::ComputeForMinibatch().
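To make the note about "deriv" being used as a temporary variable concrete, here is a minimal, self-contained C++ sketch of a backward pass over a chain of components. It is not Kaldi code: Mat, ToyComponent and BackpropSketch are hypothetical names, each toy component simply scales its input, and stopping at a first_updatable index is our reading of the Nnet::FirstUpdatableComponent() reference above. The pattern it shows: walk the components from last to first, compute the input derivative into a fresh buffer, accumulate parameter gradients into the copy being updated (if there is one), then swap so that deriv always holds the derivative for the next component down.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-in for CuMatrix<BaseFloat>: a dense row-major matrix.
using Mat = std::vector<std::vector<double>>;

// Hypothetical toy component computing y = scale * x elementwise; it is not
// Kaldi's Component class.
struct ToyComponent {
  double scale;
  double grad;  // accumulated d(objf)/d(scale) in the copy being updated

  // Writes d(objf)/d(input) to *in_deriv and, if to_update is non-null,
  // accumulates the parameter gradient into it.
  void Backprop(const Mat &in, const Mat &out_deriv, ToyComponent *to_update,
                Mat *in_deriv) const {
    *in_deriv = out_deriv;  // same shape as the input for this toy layer
    for (std::size_t r = 0; r < in.size(); ++r) {
      for (std::size_t c = 0; c < in[r].size(); ++c) {
        (*in_deriv)[r][c] = scale * out_deriv[r][c];
        if (to_update != nullptr) to_update->grad += in[r][c] * out_deriv[r][c];
      }
    }
  }
};

// forward_data[c] is the input of component c. On entry *deriv holds the
// derivative w.r.t. the network output; it is overwritten at every step.
void BackpropSketch(const std::vector<ToyComponent> &nnet,
                    std::vector<ToyComponent> *nnet_to_update,
                    const std::vector<Mat> &forward_data,
                    std::size_t first_updatable, Mat *deriv) {
  // Walk from the last component down to first_updatable (inclusive).
  for (std::size_t c = nnet.size(); c-- > first_updatable;) {
    ToyComponent *to_update =
        (nnet_to_update != nullptr) ? &(*nnet_to_update)[c] : nullptr;
    Mat input_deriv;
    nnet[c].Backprop(forward_data[c], *deriv, to_update, &input_deriv);
    std::swap(input_deriv, *deriv);  // *deriv: derivative w.r.t. component c's input
  }
}
```

Stopping at the lowest updatable component is presumably done because nothing below it has parameters to update, so derivatives need not be propagated further.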
double ComputeForMinibatch (const std::vector< NnetExample > &data, double *tot_accuracy)
Does the entire forward and backward computation for this minibatch.
Returns total objective function over this minibatch. If tot_accuracy != NULL, outputs to that pointer the total accuracy.
Definition at line 46 of file nnet-update.cc.
References NnetUpdater::Backprop(), NnetUpdater::ComputeObjfAndDeriv(), NnetUpdater::FormatInput(), NnetUpdater::nnet_to_update_, and NnetUpdater::Propagate().
Referenced by kaldi::nnet2::ComputeNnetObjf(), and kaldi::nnet2::DoBackprop().
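The description above implies a fixed sequence of stages. The sketch below uses hypothetical names throughout (UpdaterSkeleton and its stub stages are not Kaldi types) and shows one plausible ordering based on the functions this method references: FormatInput, Propagate, ComputeObjfAndDeriv, then Backprop. Skipping Backprop when there is no network to update is an assumption suggested by the reference to nnet_to_update_. The stubs only record their names so that the control flow can be checked.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical skeleton of the documented control flow. The stage bodies are
// stubs that only record their names; no Kaldi matrices are involved.
class UpdaterSkeleton {
 public:
  explicit UpdaterSkeleton(bool has_nnet_to_update)
      : has_nnet_to_update_(has_nnet_to_update) {}

  // Mirrors ComputeForMinibatch(data, tot_accuracy): format, forward,
  // objective plus output derivative, then (optionally) backward.
  double ComputeForMinibatch(const std::vector<int> &data,
                             double *tot_accuracy) {
    FormatInput(data);
    Propagate();
    std::vector<double> deriv;
    double objf = ComputeObjfAndDeriv(data, &deriv, tot_accuracy);
    if (has_nnet_to_update_)  // assumption: no update target means no backprop
      Backprop(&deriv);
    return objf;
  }

  std::vector<std::string> calls;  // order in which the stages ran

 private:
  void FormatInput(const std::vector<int> &) { calls.push_back("FormatInput"); }
  void Propagate() { calls.push_back("Propagate"); }
  double ComputeObjfAndDeriv(const std::vector<int> &data,
                             std::vector<double> *deriv,
                             double *tot_accuracy) {
    calls.push_back("ComputeObjfAndDeriv");
    deriv->assign(data.size(), 0.0);
    if (tot_accuracy != nullptr) *tot_accuracy = 0.0;
    return -1.0 * static_cast<double>(data.size());  // placeholder objective
  }
  void Backprop(std::vector<double> *) { calls.push_back("Backprop"); }

  bool has_nnet_to_update_;
};
```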
double ComputeForMinibatch (const std::vector< NnetExample > &data, Matrix< BaseFloat > *formatted_data, double *tot_accuracy)
This version of ComputeForMinibatch is used when you have already called kaldi::nnet2::FormatNnetInput() to format your data as a single matrix.
This interface is provided because it can be more efficient to do this non-trivial CPU-based formatting in a separate thread. formatted_data is an input, but this function destroys it, which is why it is passed as a pointer.
Definition at line 63 of file nnet-update.cc.
References NnetUpdater::Backprop(), NnetUpdater::chunk_info_out_, Nnet::ComputeChunkInfo(), NnetUpdater::ComputeObjfAndDeriv(), NnetUpdater::forward_data_, Nnet::InputDim(), KALDI_ASSERT, Nnet::LeftContext(), NnetUpdater::nnet_, NnetUpdater::nnet_to_update_, MatrixBase< Real >::NumCols(), Nnet::NumComponents(), MatrixBase< Real >::NumRows(), NnetUpdater::Propagate(), and Nnet::RightContext().
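The separate-thread rationale can be illustrated with a small producer/consumer sketch in standard C++: while one minibatch is being processed, the next one is formatted on a worker thread via std::async. Everything here is hypothetical: Batch, Formatted, FormatBatch, ProcessFormatted and RunPipelined stand in for the NnetExample vector, the formatted Matrix<BaseFloat>, FormatNnetInput and this overload of ComputeForMinibatch, and the "processing" is just a sum. ProcessFormatted takes its input by pointer and empties it, echoing the note that formatted_data is destroyed.

```cpp
#include <cassert>
#include <cstddef>
#include <future>
#include <vector>

// Hypothetical types: a minibatch is a list of feature rows, and its
// "formatted" form is all rows concatenated into one flat buffer.
using Batch = std::vector<std::vector<double>>;
using Formatted = std::vector<double>;

// Stand-in for the CPU-side formatting step (the role of FormatNnetInput).
Formatted FormatBatch(const Batch &batch) {
  Formatted out;
  for (const auto &row : batch) out.insert(out.end(), row.begin(), row.end());
  return out;
}

// Stand-in for ComputeForMinibatch(data, &formatted_data, ...): it consumes
// the formatted data through a pointer and leaves it empty.
double ProcessFormatted(Formatted *formatted) {
  double sum = 0.0;
  for (double x : *formatted) sum += x;
  formatted->clear();
  return sum;
}

// Formats minibatch i + 1 on a worker thread while minibatch i is processed.
double RunPipelined(const std::vector<Batch> &batches) {
  double total = 0.0;
  if (batches.empty()) return total;
  std::future<Formatted> next = std::async(
      std::launch::async, [&batches] { return FormatBatch(batches[0]); });
  for (std::size_t i = 0; i < batches.size(); ++i) {
    Formatted current = next.get();
    if (i + 1 < batches.size()) {
      next = std::async(std::launch::async,
                        [&batches, i] { return FormatBatch(batches[i + 1]); });
    }
    total += ProcessFormatted(&current);
  }
  return total;
}
```

On some older GCC/glibc toolchains, std::async with std::launch::async requires linking with -pthread.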
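double ComputeObjfAndDeriv (const std::vector< NnetExample > &data, CuMatrix< BaseFloat > *deriv, double *tot_accuracy=NULL) const [protected]
Computes the objective function and its derivative at the output layer, but does not do the backprop [for that, see Backprop()].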
Returns the objf summed over all samples (with their weights). If tot_accuracy != NULL, it outputs to tot_accuracy the sum, over all labels of all examples, of (correctly classified ? 1 : 0) * weight-of-label. This involves extra computation.
Definition at line 125 of file nnet-update.cc.
References CuMatrix< Real >::CompObjfAndDeriv(), NnetUpdater::ComputeTotAccuracy(), NnetUpdater::forward_data_, KALDI_ASSERT, KALDI_VLOG, NnetUpdater::nnet_, Nnet::NumComponents(), Nnet::OutputDim(), CuMatrix< Real >::Resize(), and kaldi::SameDim().
Referenced by NnetUpdater::ComputeForMinibatch().
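As an illustration of what "objective function and derivative at the output layer" can mean for weighted labels, here is a sketch that assumes the output rows are posterior probabilities and the objective is the weighted log-probability of each label. That is our assumption about what CuMatrix<Real>::CompObjfAndDeriv() computes; the function and type names below are hypothetical. For a label k with weight w in row i, the contribution is w * log(output[i][k]), and its derivative with respect to output[i][k] is w / output[i][k]; all other derivative entries are zero.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

using Mat = std::vector<std::vector<double>>;  // hypothetical output matrix
// Hypothetical per-example labels: (pdf-id, weight) pairs.
using Labels = std::vector<std::pair<int, double>>;

// Assumed objective: sum over rows i and labels (k, w) of w * log(output[i][k]).
// Fills *deriv with d(objf)/d(output): w / output[i][k] at labelled entries.
double ObjfAndDerivSketch(const Mat &output, const std::vector<Labels> &labels,
                          Mat *deriv) {
  assert(output.size() == labels.size());
  deriv->assign(output.size(), std::vector<double>());
  double objf = 0.0;
  for (std::size_t i = 0; i < output.size(); ++i) {
    (*deriv)[i].assign(output[i].size(), 0.0);
    for (const auto &lw : labels[i]) {
      const double p = output[i][lw.first];
      objf += lw.second * std::log(p);
      (*deriv)[i][lw.first] += lw.second / p;
    }
  }
  return objf;
}
```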
double ComputeTotAccuracy (const std::vector< NnetExample > &data) const [private]
Definition at line 161 of file nnet-update.cc.
References CuMatrixBase< Real >::FindRowMaxId(), NnetUpdater::forward_data_, KALDI_ASSERT, NnetUpdater::nnet_, Nnet::NumComponents(), and CuMatrixBase< Real >::NumRows().
Referenced by NnetUpdater::ComputeObjfAndDeriv().
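ComputeTotAccuracy is listed without a description. Its name and its reference to FindRowMaxId() suggest that it takes the arg-max of each output row as the hypothesis and sums the weights of the labels that match it. The following is a sketch under that assumption, with hypothetical names; in this sketch ties go to the first maximum, which may differ from FindRowMaxId().

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <utility>
#include <vector>

using Mat = std::vector<std::vector<double>>;        // hypothetical output matrix
using Labels = std::vector<std::pair<int, double>>;  // hypothetical (pdf-id, weight)

// Sum over all labels of all rows of weight * (label == row arg-max ? 1 : 0).
double TotAccuracySketch(const Mat &output, const std::vector<Labels> &labels) {
  assert(output.size() == labels.size());
  double tot = 0.0;
  for (std::size_t i = 0; i < output.size(); ++i) {
    assert(!output[i].empty());
    // Hypothesis for row i: index of its largest entry (first one on ties).
    const int best = static_cast<int>(std::distance(
        output[i].begin(),
        std::max_element(output[i].begin(), output[i].end())));
    for (const auto &lw : labels[i])
      if (lw.first == best) tot += lw.second;
  }
  return tot;
}
```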
void FormatInput (const std::vector< NnetExample > &data) [protected]
Formats the input as a single matrix, sets the size of forward_data_, and sets up chunk_info_out_.
Definition at line 35 of file nnet-update.cc.
References NnetUpdater::chunk_info_out_, Nnet::ComputeChunkInfo(), kaldi::nnet2::FormatNnetInput(), NnetUpdater::forward_data_, Nnet::LeftContext(), NnetUpdater::nnet_, Nnet::NumComponents(), and Nnet::RightContext().
Referenced by NnetUpdater::ComputeForMinibatch().
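A rough picture of what "formats the input as a single matrix" involves, given the references to Nnet::LeftContext() and Nnet::RightContext(): each example contributes the LeftContext() + 1 + RightContext() frames the network needs, and these blocks are stacked into one matrix. The sketch assumes each example stores a window of frames together with how many of them are left context, and that surplus leading context is skipped. ToyExample and FormatInputSketch are hypothetical, and details such as speaker-specific features are omitted.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Mat = std::vector<std::vector<double>>;

// Hypothetical simplified example: a window of feature frames, of which the
// first `left_context` rows precede the frame being classified.
struct ToyExample {
  Mat input_frames;
  int left_context;
};

// Stacks, for each example, the (left + 1 + right) frames the network needs
// into one matrix with data.size() * (left + 1 + right) rows.
Mat FormatInputSketch(const std::vector<ToyExample> &data,
                      int nnet_left, int nnet_right) {
  assert(!data.empty());
  const int num_splice = nnet_left + 1 + nnet_right;
  Mat out;
  out.reserve(data.size() * static_cast<std::size_t>(num_splice));
  for (const ToyExample &eg : data) {
    assert(eg.left_context >= nnet_left);
    const int skip = eg.left_context - nnet_left;  // drop unneeded leading context
    assert(static_cast<int>(eg.input_frames.size()) >= skip + num_splice);
    for (int t = 0; t < num_splice; ++t)
      out.push_back(eg.input_frames[skip + t]);
  }
  return out;
}
```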
void GetOutput (CuMatrix< BaseFloat > *output)
Definition at line 91 of file nnet-update.cc.
References NnetUpdater::forward_data_, KALDI_ASSERT, NnetUpdater::nnet_, and Nnet::NumComponents().
void Propagate () [protected]
Definition at line 97 of file nnet-update.cc.
References Component::BackpropNeedsInput(), Component::BackpropNeedsOutput(), NnetUpdater::chunk_info_out_, NnetUpdater::forward_data_, kaldi::g_kaldi_verbose_level, Nnet::GetComponent(), KALDI_VLOG, kaldi::kTrans, NnetUpdater::nnet_, Nnet::NumComponents(), Component::Propagate(), and kaldi::TraceMatMat().
Referenced by NnetUpdater::ComputeForMinibatch().
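Propagate is listed without a description. Its references (Component::Propagate(), Component::BackpropNeedsInput(), Component::BackpropNeedsOutput() and forward_data_) suggest a forward loop in which forward_data_[c + 1] receives the output of component c, and a buffer is released once neither the component that reads it nor the component that produced it needs it for backprop. The sketch below implements that pattern with hypothetical ToyLayer components that scale their input; it illustrates the idea and is not the Kaldi implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Mat = std::vector<std::vector<double>>;

// Hypothetical component: elementwise y = scale * x.
struct ToyLayer {
  double scale;
  bool backprop_needs_input;   // stand-in for Component::BackpropNeedsInput()
  bool backprop_needs_output;  // stand-in for Component::BackpropNeedsOutput()
  Mat Propagate(const Mat &in) const {
    Mat out = in;
    for (auto &row : out)
      for (double &x : row) x *= scale;
    return out;
  }
};

// forward_data must have layers.size() + 1 entries, with forward_data[0] set
// to the formatted input. Afterwards forward_data[c + 1] holds the output of
// layer c, and forward_data[c] is freed if neither layer c (as its input) nor
// layer c - 1 (as its output) needs it for backprop.
void PropagateSketch(const std::vector<ToyLayer> &layers,
                     std::vector<Mat> *forward_data) {
  assert(forward_data->size() == layers.size() + 1);
  for (std::size_t c = 0; c < layers.size(); ++c) {
    (*forward_data)[c + 1] = layers[c].Propagate((*forward_data)[c]);
    const bool keep = layers[c].backprop_needs_input ||
                      (c > 0 && layers[c - 1].backprop_needs_output);
    if (!keep) (*forward_data)[c].clear();  // save memory
  }
}
```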
friend class NnetEnsembleTrainer [friend]
Definition at line 98 of file nnet-update.h.
std::vector< ChunkInfo > chunk_info_out_ [private]
Definition at line 106 of file nnet-update.h.
Referenced by NnetUpdater::Backprop(), NnetUpdater::ComputeForMinibatch(), NnetUpdater::FormatInput(), and NnetUpdater::Propagate().
std::vector< CuMatrix< BaseFloat > > forward_data_ [private]
Definition at line 108 of file nnet-update.h.
Referenced by NnetUpdater::Backprop(), NnetUpdater::ComputeForMinibatch(), NnetUpdater::ComputeObjfAndDeriv(), NnetUpdater::ComputeTotAccuracy(), NnetUpdater::FormatInput(), NnetUpdater::GetOutput(), and NnetUpdater::Propagate().
const Nnet & nnet_ [private]
Definition at line 103 of file nnet-update.h.
Referenced by NnetUpdater::Backprop(), NnetUpdater::ComputeForMinibatch(), NnetUpdater::ComputeObjfAndDeriv(), NnetUpdater::ComputeTotAccuracy(), NnetUpdater::FormatInput(), NnetUpdater::GetOutput(), and NnetUpdater::Propagate().
Nnet * nnet_to_update_ [private]
Definition at line 104 of file nnet-update.h.
Referenced by NnetUpdater::Backprop(), and NnetUpdater::ComputeForMinibatch().
int32 num_chunks_ [private]
Definition at line 105 of file nnet-update.h.