class NnetComputer is responsible for executing the computation described in the "computation" object.
#include <nnet-compute.h>
Classes
  struct CommandDebugInfo
Public Member Functions
  NnetComputer(const NnetComputeOptions &options, const NnetComputation &computation, const Nnet &nnet, Nnet *nnet_to_update)
      Constructor.
  NnetComputer(const NnetComputeOptions &options, const NnetComputation &computation, Nnet *nnet, Nnet *nnet_to_update)
      This version of the constructor accepts a pointer to 'nnet' instead of a const reference.
  NnetComputer(const NnetComputer &other)
      Copy constructor.
  void AcceptInput(const std::string &node_name, CuMatrix<BaseFloat> *input)
      Supplies an input matrix (or, for an output node, a derivative), e.g. AcceptInput("input", &input_mat).
  void AcceptInputs(const Nnet &nnet, const std::vector<NnetIo> &io)
      This convenience function calls AcceptInput() in turn on all the inputs in the training example.
  void Run()
      Does either the forward or the backward computation, depending on when it is called.
  const CuMatrixBase<BaseFloat> &GetOutput(const std::string &node_name)
  void GetOutputDestructive(const std::string &output_name, CuMatrix<BaseFloat> *output)
  ~NnetComputer()
Private Member Functions
  void Init()
  void ExecuteCommand()
  int32 GetIoMatrixIndex(const std::string &node_name, bool is_output)
  void CheckNoPendingIo()
  CuSubMatrix<BaseFloat> GetSubMatrix(int32 submatrix_index)
  void GetPointers(int32 indexes_multi_index, int32 num_cols, CuArray<BaseFloat *> *pointers)
  void GetPointers(int32 indexes_multi_index, int32 num_cols, CuArray<const BaseFloat *> *pointers)
  void DebugBeforeExecute(int32 command, CommandDebugInfo *info)
  void DebugAfterExecute(int32 command, const CommandDebugInfo &info, double command_execution_time)
  void SaveMemo(int32 memo_index, const Component &c, void *memo)
  void *GetMemo(int32 memo_index)
  NnetComputer &operator=(const NnetComputer &other)
Static Private Member Functions
  static BaseFloat MatrixStddev(const CuMatrixBase<BaseFloat> &m)
  static BaseFloat ParameterStddev(const Component &c)
Private Attributes
  const NnetComputeOptions &options_
  const NnetComputation &computation_
  const Nnet &nnet_
  int32 program_counter_
  std::vector<int32> pending_commands_
  Nnet *nnet_to_store_stats_
  Nnet *nnet_to_update_
  bool debug_
  std::vector<CommandAttributes> command_attributes_
  std::vector<std::string> submatrix_strings_
  std::vector<std::string> command_strings_
  std::vector<CuMatrix<BaseFloat> > matrices_
  std::vector<void *> memos_
  std::vector<CuCompressedMatrixBase *> compressed_matrices_
class NnetComputer is responsible for executing the computation described in the "computation" object.
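You call, in sequence: the constructor; then AcceptInput() [or AcceptInputs()]; then Run(); then GetOutput() [and, if applicable, supply the output derivatives]; then, if there is a backward computation, Run() again [then, if applicable, retrieve the input derivatives]. There are no separate AcceptOutputDeriv() or GetInputDeriv() members: output derivatives are supplied with AcceptInput() using the output node's name, and input derivatives are retrieved with GetOutput() using the input node's name.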
Definition at line 59 of file nnet-compute.h.
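As a rough illustration of that call sequence, here is a self-contained C++ toy. It is not Kaldi code: ToyComputer, its one-parameter "network" y = w * x (applied elementwise), and the one-slot-per-node-name storage are all hypothetical, and the real class is driven by the commands of a compiled NnetComputation rather than by hard-coded forward and backward phases. What the toy does mirror is the calling convention: the input, and later the output derivative, are handed to AcceptInput() under the node's name, emptying the caller's vector; the first Run() is the forward pass and the second the backward pass; and, in this toy, both the output and the input derivative are read back with GetOutput().

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

using Vec = std::vector<double>;

// Hypothetical toy, not Kaldi: a one-parameter "network" y = w * x
// (elementwise), with an input node "input" and an output node "output".
class ToyComputer {
 public:
  explicit ToyComputer(double w) : w_(w) {}

  // Inputs and output derivatives both arrive here; the caller's data is
  // taken over and the argument is left empty.
  void AcceptInput(const std::string &node_name, Vec *input) {
    slots_[node_name].swap(*input);
    input->clear();
  }

  // First call: forward pass.  Second call: backward pass.
  void Run() {
    if (num_runs_ == 0) {
      const Vec &x = slots_.at("input");
      Vec y(x.size());
      for (std::size_t i = 0; i < x.size(); ++i) y[i] = w_ * x[i];
      slots_["output"] = y;
    } else if (num_runs_ == 1) {
      const Vec &dy = slots_.at("output");  // now holds d(objf)/dy
      Vec dx(dy.size());
      for (std::size_t i = 0; i < dy.size(); ++i) dx[i] = w_ * dy[i];
      slots_["input"] = dx;  // d(objf)/dx, read back via GetOutput("input")
    } else {
      throw std::logic_error("computation has already finished");
    }
    ++num_runs_;
  }

  const Vec &GetOutput(const std::string &node_name) const {
    return slots_.at(node_name);
  }

 private:
  double w_;
  int num_runs_ = 0;
  std::map<std::string, Vec> slots_;  // one slot per node name (simplified)
};
```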
NnetComputer(const NnetComputeOptions &options, const NnetComputation &computation, const Nnet &nnet, Nnet *nnet_to_update)
Constructor.
'nnet_to_update' should be NULL if you are not doing a model update or a model-derivative computation. You must call computation.ComputeCudaIndexes() before calling this function.
Caution: there is another constructor that takes a pointer for 'nnet'; be careful not to mix the two up.
Definition at line 28 of file nnet-compute.cc.
References NnetComputer::Init().
NnetComputer(const NnetComputeOptions &options, const NnetComputation &computation, Nnet *nnet, Nnet *nnet_to_update)
This version of the constructor accepts a pointer to 'nnet' instead of a const reference.
The difference is that, for storing statistics (the StoreStats() function of class Component), this version uses 'nnet', whereas the const-reference version uses 'nnet_to_update' (if that is non-NULL).
Definition at line 38 of file nnet-compute.cc.
References NnetComputer::Init().
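A hypothetical C++ toy illustrating the two constructor overloads. Nnet here is a stand-in struct, not Kaldi's class, and only the choice of where statistics would be stored is modelled. The assumption, based on the description above, is that the const-reference overload stores statistics in 'nnet_to_update' (so none are stored when it is NULL), while the pointer overload stores them in 'nnet' itself.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for kaldi::nnet3::Nnet.
struct Nnet {
  std::string name;
};

// Toy model of the two constructor overloads.
class ToyComputer {
 public:
  // Like the const-reference overload: statistics go to nnet_to_update,
  // which may be NULL (then no statistics are stored).
  ToyComputer(const Nnet &nnet, Nnet *nnet_to_update)
      : nnet_(nnet), nnet_to_store_stats_(nnet_to_update),
        nnet_to_update_(nnet_to_update) {}

  // Like the pointer overload: statistics go to 'nnet' itself.
  ToyComputer(Nnet *nnet, Nnet *nnet_to_update)
      : nnet_(*nnet), nnet_to_store_stats_(nnet),
        nnet_to_update_(nnet_to_update) {}

  const Nnet &Model() const { return nnet_; }
  Nnet *StatsTarget() const { return nnet_to_store_stats_; }
  Nnet *UpdateTarget() const { return nnet_to_update_; }

 private:
  const Nnet &nnet_;
  Nnet *nnet_to_store_stats_;
  Nnet *nnet_to_update_;
};
```

Passing `model` rather than `&model` silently selects the other overload, which is the mix-up the Caution above warns about.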
NnetComputer(const NnetComputer &other)
Copy constructor.
May not be used if memos are stored in this object. That is only possible if backprop will take place, and in that situation you would not normally want to use the copy constructor anyway; it is more useful for things like RNNLM lattice rescoring.
Definition at line 189 of file nnet-compute.cc.
References KALDI_ERR, and NnetComputer::memos_.
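The following toy, not taken from Kaldi, sketches the rule stated above. Its member names echo SaveMemo(), GetMemo() and memos_ from the member list, but their bodies are assumptions: here a memo is an opaque pointer saved during the forward pass and handed back, with its slot cleared, during backprop, and the copy constructor refuses to copy while any slot is occupied. A std::logic_error stands in for KALDI_ERR. The real class declares operator= private, presumably to forbid assignment; the toy simply deletes it.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical toy: models only the memo slots and the copy-constructor rule.
class ToyComputer {
 public:
  ToyComputer() = default;

  // Copying is allowed only while no memo is held.
  ToyComputer(const ToyComputer &other) : memos_(other.memos_) {
    for (void *memo : memos_)
      if (memo != nullptr)
        throw std::logic_error("cannot copy a computer that holds memos");
  }
  ToyComputer &operator=(const ToyComputer &) = delete;

  // Forward pass: keep a component's memo for later use in backprop.
  void SaveMemo(std::size_t index, void *memo) {
    if (memos_.size() <= index) memos_.resize(index + 1, nullptr);
    memos_[index] = memo;
  }

  // Backward pass: hand the memo back and clear its slot.
  void *GetMemo(std::size_t index) {
    void *memo = memos_.at(index);
    memos_[index] = nullptr;
    return memo;
  }

 private:
  std::vector<void *> memos_;
};
```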
~NnetComputer()
Definition at line 680 of file nnet-compute.cc.
References NnetComputer::compressed_matrices_.
void AcceptInput(const std::string &node_name, CuMatrix<BaseFloat> *input)
Supplies the matrix for a named node, e.g. AcceptInput("input", &input_mat), or, for derivatives w.r.t. the output, AcceptInput("output", &output_deriv_mat). Fails with an error if there is no input or output node with the given name. This function is destructive of 'input': it takes the data using the Swap() function of CuMatrix. The matrix must have the same number of rows as the corresponding input described in the ComputationRequest, i.e. indexes.size() in the corresponding IoSpecification.
Definition at line 547 of file nnet-compute.cc.
References NnetComputer::computation_, NnetComputer::GetIoMatrixIndex(), KALDI_ERR, kaldi::kDefaultStride, kaldi::kStrideEqualNumCols, kaldi::kUndefined, NnetComputation::matrices, NnetComputer::matrices_, NnetComputation::MatrixInfo::num_cols, NnetComputation::MatrixInfo::num_rows, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), CuMatrix< Real >::Resize(), CuMatrixBase< Real >::Stride(), and NnetComputation::MatrixInfo::stride_type.
Referenced by NnetComputer::AcceptInputs(), DecodableNnetLoopedOnlineBase::AdvanceChunk(), DecodableNnetSimpleLooped::AdvanceChunk(), NnetBatchComputer::Compute(), kaldi::nnet3::ComputeObjectiveFunction(), BatchedXvectorComputer::ComputeOneBatch(), DecodableNnetSimple::DoNnetComputation(), NnetDiscriminativeComputeObjf::ProcessOutputs(), NnetChainTrainer::ProcessOutputs(), NnetChainComputeProb::ProcessOutputs(), NnetDiscriminativeTrainer::ProcessOutputs(), kaldi::nnet3::RunNnetComputation(), kaldi::nnet3::UnitTestNnetCompute(), kaldi::nnet3::UnitTestNnetInputDerivatives(), kaldi::nnet3::UnitTestNnetModelDerivatives(), and kaldi::nnet3::UnitTestNnetOptimizeWithOptions().
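A minimal sketch of two contract points above, the row-count check and the destructive, swap-based hand-off, using hypothetical ToyMatrix and ToyComputer types in place of CuMatrix and NnetComputer. The expected_rows map stands in for indexes.size() of each IoSpecification. The real function also appears to deal with stride requirements (its references include stride_type and kStrideEqualNumCols), which this toy ignores.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for CuMatrix<BaseFloat>: row-major storage.
struct ToyMatrix {
  std::size_t rows = 0, cols = 0;
  std::vector<float> data;
  void Swap(ToyMatrix *other) {  // exchanges buffers, copies no elements
    std::swap(rows, other->rows);
    std::swap(cols, other->cols);
    data.swap(other->data);
  }
};

class ToyComputer {
 public:
  // expected_rows plays the role of indexes.size() in each IoSpecification.
  explicit ToyComputer(std::map<std::string, std::size_t> expected_rows)
      : expected_rows_(std::move(expected_rows)) {}

  void AcceptInput(const std::string &node_name, ToyMatrix *input) {
    auto it = expected_rows_.find(node_name);
    if (it == expected_rows_.end())
      throw std::runtime_error("no input or output node named " + node_name);
    if (input->rows != it->second)
      throw std::runtime_error("wrong number of rows for " + node_name);
    matrices_[node_name].Swap(input);  // take the caller's data
    *input = ToyMatrix();              // leave the caller's matrix empty
  }

  const ToyMatrix &Stored(const std::string &node_name) const {
    return matrices_.at(node_name);
  }

 private:
  std::map<std::string, std::size_t> expected_rows_;
  std::map<std::string, ToyMatrix> matrices_;
};
```

Swapping exchanges the underlying buffers in constant time, so a large feature matrix is handed over without a copy; the price is that the caller's matrix is left empty.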
void AcceptInputs(const Nnet &nnet, const std::vector<NnetIo> &io)
This convenience function calls AcceptInput() in turn on all the inputs in the training example.
It needs "nnet" only in order to distinguish inputs from outputs.
Definition at line 663 of file nnet-compute.cc.
References NnetComputer::AcceptInput(), CuMatrixBase< Real >::CopyFromGeneralMat(), NnetIo::features, Nnet::GetNodeIndex(), Nnet::IsInputNode(), KALDI_ERR, kaldi::kUndefined, NnetIo::name, GeneralMatrix::NumCols(), and GeneralMatrix::NumRows().
Referenced by NnetLdaStatsAccumulator::AccStats(), NnetComputerFromEg::Compute(), NnetDiscriminativeComputeObjf::Compute(), NnetChainComputeProb::Compute(), NnetComputeProb::Compute(), NnetDiscriminativeTrainer::Train(), NnetChainTrainer::TrainInternal(), NnetTrainer::TrainInternal(), NnetChainTrainer::TrainInternalBackstitch(), and NnetTrainer::TrainInternalBackstitch().
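A hypothetical sketch of the behaviour described above, with toy stand-ins (ToyNnet, ToyIo, ToyComputer) for Nnet, NnetIo and NnetComputer. Entries whose node is not an input (for example supervision attached to an output node) are skipped, and an unknown node name is an error. Because the examples are passed by const reference while AcceptInput() is destructive, each input is copied first; the references to CopyFromGeneralMat() and kUndefined suggest the real function does something similar, but that is an inference.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-ins for Nnet and NnetIo.
struct ToyNnet {
  std::set<std::string> input_nodes, output_nodes;
  bool HasNode(const std::string &n) const {
    return input_nodes.count(n) > 0 || output_nodes.count(n) > 0;
  }
  bool IsInputNode(const std::string &n) const {
    return input_nodes.count(n) > 0;
  }
};

struct ToyIo {
  std::string name;             // node name, e.g. "input" or "output"
  std::vector<float> features;  // flattened matrix
};

class ToyComputer {
 public:
  // Destructive, like AcceptInput(): takes the caller's buffer.
  void AcceptInput(const std::string &node_name, std::vector<float> *input) {
    received_[node_name].swap(*input);
    input->clear();
  }

  // Feeds every input in 'io' to AcceptInput(); entries for output nodes
  // are skipped, which is why the network is needed.
  void AcceptInputs(const ToyNnet &nnet, const std::vector<ToyIo> &io) {
    for (const ToyIo &item : io) {
      if (!nnet.HasNode(item.name))
        throw std::runtime_error("no node named '" + item.name + "'");
      if (!nnet.IsInputNode(item.name)) continue;
      std::vector<float> copy(item.features);  // 'io' is const: copy first
      AcceptInput(item.name, &copy);
    }
  }

  bool Received(const std::string &n) const { return received_.count(n) > 0; }
  std::size_t NumReceived() const { return received_.size(); }

 private:
  std::map<std::string, std::vector<float>> received_;
};
```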
void CheckNoPendingIo() [private]
Definition at line 597 of file nnet-compute.cc.
References NnetComputation::commands, NnetComputer::computation_, Nnet::GetNodeName(), kaldi::nnet3::kAcceptInput, KALDI_ERR, kaldi::nnet3::kProvideOutput, NnetComputer::nnet_, NnetComputer::pending_commands_, and NnetComputer::program_counter_.
Referenced by NnetComputer::Run().
void DebugAfterExecute(int32 command, const CommandDebugInfo &info, double command_execution_time) [private]
Definition at line 116 of file nnet-compute.cc.
References NnetComputation::Command::arg1, NnetComputer::command_attributes_, NnetComputer::command_strings_, NnetComputation::Command::command_type, NnetComputation::commands, NnetComputer::CommandDebugInfo::components_parameter_stddev, NnetComputer::computation_, Nnet::GetComponent(), Nnet::GetComponentName(), NnetComputer::GetSubMatrix(), NnetComputation::IsWholeMatrix(), KALDI_ASSERT, KALDI_LOG, kaldi::nnet3::kBackprop, kaldi::nnet3::kUpdatableComponent, NnetComputer::matrices_, NnetComputer::CommandDebugInfo::matrices_written_stddevs, NnetComputer::MatrixStddev(), NnetComputer::nnet_, NnetComputer::ParameterStddev(), Component::Properties(), NnetComputer::CommandDebugInfo::submatrices_written_stddevs, and NnetComputer::submatrix_strings_.
Referenced by NnetComputer::Run().
void DebugBeforeExecute(int32 command, CommandDebugInfo *info) [private]
Definition at line 82 of file nnet-compute.cc.
References NnetComputation::Command::arg1, NnetComputer::command_attributes_, NnetComputation::Command::command_type, NnetComputation::commands, NnetComputer::CommandDebugInfo::components_parameter_stddev, NnetComputer::computation_, Nnet::GetComponent(), NnetComputer::GetSubMatrix(), NnetComputation::IsWholeMatrix(), kaldi::nnet3::kBackprop, kaldi::nnet3::kUpdatableComponent, NnetComputer::matrices_, NnetComputer::CommandDebugInfo::matrices_written_stddevs, NnetComputer::MatrixStddev(), NnetComputer::nnet_, NnetComputer::ParameterStddev(), Component::Properties(), and NnetComputer::CommandDebugInfo::submatrices_written_stddevs.
Referenced by NnetComputer::Run().
void ExecuteCommand() [private]
Definition at line 210 of file nnet-compute.cc.
References CuMatrixBase< Real >::AddMat(), CuMatrixBase< Real >::AddRowRanges(), CuMatrixBase< Real >::AddRows(), CuMatrixBase< Real >::AddToRows(), NnetComputation::Command::alpha, NnetComputation::Command::arg1, NnetComputation::Command::arg2, NnetComputation::Command::arg3, NnetComputation::Command::arg4, NnetComputation::Command::arg5, NnetComputation::Command::arg6, NnetComputation::Command::arg7, Component::Backprop(), NnetComputer::command_strings_, NnetComputation::Command::command_type, NnetComputation::commands, NnetComputation::component_precomputed_indexes, NnetComputer::compressed_matrices_, NnetComputer::computation_, CuMatrixBase< Real >::CopyFromMat(), CuMatrixBase< Real >::CopyRows(), CuCompressedMatrixBase::CopyToMat(), CuMatrixBase< Real >::CopyToRows(), NnetComputer::debug_, Component::DeleteMemo(), NnetComputation::GetCommandStrings(), Nnet::GetComponent(), Nnet::GetComponentName(), NnetComputer::GetMemo(), NnetComputer::GetPointers(), NnetComputer::GetSubMatrix(), NnetComputation::indexes_cuda, NnetComputation::indexes_ranges_cuda, kaldi::nnet3::kAddRowRanges, kaldi::nnet3::kAddRows, kaldi::nnet3::kAddRowsMulti, kaldi::nnet3::kAddToRowsMulti, KALDI_ASSERT, KALDI_ERR, KALDI_LOG, KALDI_WARN, kaldi::nnet3::kAllocMatrix, kaldi::nnet3::kBackprop, kaldi::nnet3::kBackpropNoModelUpdate, kaldi::nnet3::kCompressMatrix, kaldi::nnet3::kCopyRows, kaldi::nnet3::kCopyRowsMulti, kaldi::nnet3::kCopyToRowsMulti, kaldi::nnet3::kDeallocMatrix, kaldi::nnet3::kDecompressMatrix, kaldi::nnet3::kGotoLabel, kaldi::nnet3::kMatrixAdd, kaldi::nnet3::kMatrixCopy, kaldi::nnet3::kNoOperation, kaldi::nnet3::kNoOperationLabel, kaldi::nnet3::kNoOperationMarker, kaldi::nnet3::kNoOperationPermanent, kaldi::nnet3::kPropagate, kaldi::nnet3::kSetConst, kaldi::nnet3::kSwapMatrix, kaldi::kUndefined, kaldi::nnet3::kUpdatableComponent, NnetComputation::matrices, NnetComputer::matrices_, NnetComputation::need_model_derivative, kaldi::NewCuCompressedMatrix(), NnetComputer::nnet_, 
NnetComputer::nnet_to_store_stats_, NnetComputer::nnet_to_update_, CuCompressedMatrixBase::NumCols(), CuMatrixBase< Real >::NumCols(), CuCompressedMatrixBase::NumRows(), NVTX_RANGE, NnetComputer::program_counter_, Component::Propagate(), Component::Properties(), NnetComputer::SaveMemo(), CuMatrixBase< Real >::Scale(), CuMatrixBase< Real >::Set(), CuMatrixBase< Real >::SetZero(), Component::StoreStats(), and NnetComputation::submatrices.
Referenced by NnetComputer::Run().
int32 GetIoMatrixIndex(const std::string &node_name, bool is_output) [private]
Definition at line 620 of file nnet-compute.cc.
References NnetComputation::commands, NnetComputer::computation_, Nnet::GetNodeIndex(), NnetComputation::IsWholeMatrix(), kaldi::nnet3::kAcceptInput, KALDI_ERR, kaldi::nnet3::kNoOperationMarker, kaldi::nnet3::kProvideOutput, NnetComputer::nnet_, NnetComputer::pending_commands_, NnetComputer::program_counter_, and NnetComputation::submatrices.
Referenced by NnetComputer::AcceptInput(), NnetComputer::GetOutput(), and NnetComputer::GetOutputDestructive().
void *GetMemo(int32 memo_index) [inline, private]
Definition at line 176 of file nnet-compute.cc.
References KALDI_ERR, and NnetComputer::memos_.
Referenced by NnetComputer::ExecuteCommand().
const CuMatrixBase<BaseFloat> &GetOutput(const std::string &node_name)
Definition at line 578 of file nnet-compute.cc.
References NnetComputer::GetIoMatrixIndex(), KALDI_ASSERT, and NnetComputer::matrices_.
Referenced by kaldi::nnet3::ComputeObjectiveFunction(), NnetDiscriminativeComputeObjf::ProcessOutputs(), NnetChainTrainer::ProcessOutputs(), NnetChainComputeProb::ProcessOutputs(), NnetDiscriminativeTrainer::ProcessOutputs(), NnetComputeProb::ProcessOutputs(), kaldi::nnet3::UnitTestNnetCompute(), kaldi::nnet3::UnitTestNnetInputDerivatives(), kaldi::nnet3::UnitTestNnetModelDerivatives(), and kaldi::nnet3::UnitTestNnetOptimizeWithOptions().
void GetOutputDestructive(const std::string &output_name, CuMatrix<BaseFloat> *output)
Definition at line 587 of file nnet-compute.cc.
References NnetComputer::GetIoMatrixIndex(), KALDI_ASSERT, and NnetComputer::matrices_.
Referenced by DecodableNnetLoopedOnlineBase::AdvanceChunk(), DecodableNnetSimpleLooped::AdvanceChunk(), NnetBatchComputer::Compute(), BatchedXvectorComputer::ComputeOneBatch(), DecodableNnetSimple::DoNnetComputation(), and kaldi::nnet3::RunNnetComputation().
void GetPointers(int32 indexes_multi_index, int32 num_cols, CuArray<BaseFloat *> *pointers) [private]
Definition at line 459 of file nnet-compute.cc.
References NnetComputer::computation_, CuArray< T >::CopyFromVec(), CuMatrixBase< Real >::Data(), NnetComputer::GetSubMatrix(), NnetComputation::indexes_multi, KALDI_ASSERT, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), kaldi::RandInt(), and CuMatrixBase< Real >::Stride().
Referenced by NnetComputer::ExecuteCommand(), and NnetComputer::GetPointers().
void GetPointers(int32 indexes_multi_index, int32 num_cols, CuArray<const BaseFloat *> *pointers) [private]
Definition at line 505 of file nnet-compute.cc.
References NnetComputer::GetPointers().
CuSubMatrix<BaseFloat> GetSubMatrix(int32 submatrix_index) [private]
Definition at line 449 of file nnet-compute.cc.
References NnetComputer::computation_, KALDI_PARANOID_ASSERT, NnetComputer::matrices_, and NnetComputation::submatrices.
Referenced by NnetComputer::DebugAfterExecute(), NnetComputer::DebugBeforeExecute(), NnetComputer::ExecuteCommand(), and NnetComputer::GetPointers().
void Init() [private]
Definition at line 48 of file nnet-compute.cc.
References NnetComputer::command_attributes_, NnetComputer::command_strings_, NnetComputer::computation_, kaldi::nnet3::ComputeCommandAttributes(), NnetComputeOptions::debug, NnetComputer::debug_, NnetComputation::GetCommandStrings(), NnetComputation::GetSubmatrixStrings(), kaldi::GetVerboseLevel(), NnetComputation::indexes, NnetComputation::indexes_cuda, NnetComputation::indexes_ranges, NnetComputation::indexes_ranges_cuda, ComputationVariables::Init(), KALDI_ASSERT, KALDI_LOG, NnetComputation::matrices, NnetComputer::matrices_, NnetComputer::nnet_, NnetComputer::options_, and NnetComputer::submatrix_strings_.
Referenced by NnetComputer::NnetComputer().
static BaseFloat MatrixStddev(const CuMatrixBase<BaseFloat> &m) [static, private]
Definition at line 68 of file nnet-compute.cc.
References kaldi::kTrans, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), and kaldi::TraceMatMat().
Referenced by NnetComputer::DebugAfterExecute(), and NnetComputer::DebugBeforeExecute().
static BaseFloat ParameterStddev(const Component &c) [static, private]
Definition at line 75 of file nnet-compute.cc.
References UpdatableComponent::DotProduct(), KALDI_ASSERT, and UpdatableComponent::NumParameters().
Referenced by NnetComputer::DebugAfterExecute(), and NnetComputer::DebugBeforeExecute().
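The MatrixStddev() and ParameterStddev() entries list only references, not descriptions. Judging from those references (TraceMatMat() with kTrans plus NumRows()/NumCols() for MatrixStddev(); DotProduct() and NumParameters() for ParameterStddev()), both debug helpers appear to compute a root-mean-square value rather than a mean-subtracted standard deviation. The sketch below is an assumption built on that reading; MatrixRms() and ParameterRms() are hypothetical names operating on plain std::vector data.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Sketch of what MatrixStddev() appears to compute for a rows x cols matrix
// stored row-major in 'm': sqrt(tr(M M^T) / (rows * cols)).  tr(M M^T) is
// the sum of squared entries, so this is a root-mean-square value.
double MatrixRms(const std::vector<double> &m, std::size_t rows,
                 std::size_t cols) {
  if (rows == 0 || cols == 0) return 0.0;
  double trace = 0.0;
  for (double v : m) trace += v * v;
  return std::sqrt(trace / static_cast<double>(rows * cols));
}

// Sketch of what ParameterStddev() appears to compute from a component's
// flattened parameters theta: sqrt(<theta, theta> / N).
double ParameterRms(const std::vector<double> &theta) {
  if (theta.empty()) return 0.0;
  double dot = 0.0;
  for (double p : theta) dot += p * p;
  return std::sqrt(dot / static_cast<double>(theta.size()));
}
```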
void Run()
This does either the forward or the backward computation, depending on when it is called. In a typical computation, the first call does the forward computation; you then take the outputs and provide derivatives, and the second call does the backward computation.
There used to be two separate functions, Forward() and Backward().
Definition at line 512 of file nnet-compute.cc.
References NnetComputer::CheckNoPendingIo(), NnetComputation::commands, NnetComputer::computation_, NnetComputer::debug_, NnetComputer::DebugAfterExecute(), NnetComputer::DebugBeforeExecute(), Timer::Elapsed(), NnetComputer::ExecuteCommand(), kaldi::nnet3::kAcceptInput, KALDI_ERR, kaldi::nnet3::kProvideOutput, NnetComputer::nnet_, NVTX_RANGE, NnetComputation::Print(), and NnetComputer::program_counter_.
Referenced by DecodableNnetLoopedOnlineBase::AdvanceChunk(), DecodableNnetSimpleLooped::AdvanceChunk(), NnetDiscriminativeComputeObjf::Compute(), NnetChainComputeProb::Compute(), NnetComputeProb::Compute(), NnetBatchComputer::Compute(), BatchedXvectorComputer::ComputeOneBatch(), DecodableNnetSimple::DoNnetComputation(), kaldi::nnet3::RunNnetComputation(), NnetDiscriminativeTrainer::Train(), NnetChainTrainer::TrainInternal(), NnetTrainer::TrainInternal(), NnetChainTrainer::TrainInternalBackstitch(), NnetTrainer::TrainInternalBackstitch(), kaldi::nnet3::UnitTestNnetCompute(), kaldi::nnet3::UnitTestNnetInputDerivatives(), kaldi::nnet3::UnitTestNnetModelDerivatives(), and kaldi::nnet3::UnitTestNnetOptimizeWithOptions().
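The reference lists of Run() and CheckNoPendingIo() (program_counter_, pending_commands_, kAcceptInput, kProvideOutput) suggest how a single Run() knows where to stop: it executes commands from the program counter onward and halts at the next command that needs the caller, such as an input to accept or an output to provide. The toy below sketches that idea under stated assumptions; the Cmd enum, ToyProgram, and its simple input counter are hypothetical and much simpler than Kaldi's actual bookkeeping.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical command kinds; Kaldi's real command set is much larger.
enum class Cmd { kCompute, kAcceptInput, kProvideOutput };

class ToyProgram {
 public:
  explicit ToyProgram(std::vector<Cmd> cmds) : cmds_(std::move(cmds)) {}

  // Stand-in for AcceptInput(): records that one input has been supplied.
  void AcceptInput() { ++inputs_given_; }

  // Executes compute commands from the program counter until the next
  // I/O command; returns how many compute commands were executed.
  int Run() {
    if (pc_ >= cmds_.size())
      throw std::logic_error("running a computation that has finished");
    // Step over I/O commands at the current position.  An input that is
    // required but was never supplied is an error.
    while (pc_ < cmds_.size() && cmds_[pc_] != Cmd::kCompute) {
      if (cmds_[pc_] == Cmd::kAcceptInput) {
        if (inputs_given_ == 0)
          throw std::runtime_error("input was not provided");
        --inputs_given_;
      }
      ++pc_;
    }
    int executed = 0;
    for (; pc_ < cmds_.size(); ++pc_) {
      if (cmds_[pc_] != Cmd::kCompute) break;  // needs the caller: stop here
      ++executed;                              // "execute" the command
    }
    return executed;
  }

  std::size_t pc() const { return pc_; }

 private:
  std::vector<Cmd> cmds_;
  std::size_t pc_ = 0;
  int inputs_given_ = 0;
};
```

With the command list accept, compute, compute, provide, accept, compute, provide, the first Run() executes two commands and stops before the first "provide"; once another input has been supplied, the second Run() executes the remaining compute command and stops before the final "provide".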
void SaveMemo(int32 memo_index, const Component &c, void *memo) [private]
Definition at line 163 of file nnet-compute.cc.
References Component::DeleteMemo(), and NnetComputer::memos_.
Referenced by NnetComputer::ExecuteCommand().
std::vector<CommandAttributes> command_attributes_ [private]
Definition at line 153 of file nnet-compute.h.
Referenced by NnetComputer::DebugAfterExecute(), NnetComputer::DebugBeforeExecute(), and NnetComputer::Init().
std::vector<std::string> command_strings_ [private]
Definition at line 157 of file nnet-compute.h.
Referenced by NnetComputer::DebugAfterExecute(), NnetComputer::ExecuteCommand(), and NnetComputer::Init().
std::vector<CuCompressedMatrixBase *> compressed_matrices_ [private]
Definition at line 173 of file nnet-compute.h.
Referenced by NnetComputer::ExecuteCommand(), and NnetComputer::~NnetComputer().
const NnetComputation &computation_ [private]
Definition at line 133 of file nnet-compute.h.
Referenced by NnetComputer::AcceptInput(), NnetComputer::CheckNoPendingIo(), NnetComputer::DebugAfterExecute(), NnetComputer::DebugBeforeExecute(), NnetComputer::ExecuteCommand(), NnetComputer::GetIoMatrixIndex(), NnetComputer::GetPointers(), NnetComputer::GetSubMatrix(), NnetComputer::Init(), and NnetComputer::Run().
bool debug_ [private]
Definition at line 151 of file nnet-compute.h.
Referenced by NnetComputer::ExecuteCommand(), NnetComputer::Init(), and NnetComputer::Run().
std::vector<CuMatrix<BaseFloat> > matrices_ [private]
Definition at line 160 of file nnet-compute.h.
Referenced by NnetComputer::AcceptInput(), NnetComputer::DebugAfterExecute(), NnetComputer::DebugBeforeExecute(), NnetComputer::ExecuteCommand(), NnetComputer::GetOutput(), NnetComputer::GetOutputDestructive(), NnetComputer::GetSubMatrix(), and NnetComputer::Init().
std::vector<void *> memos_ [private]
Definition at line 165 of file nnet-compute.h.
Referenced by NnetComputer::GetMemo(), NnetComputer::NnetComputer(), and NnetComputer::SaveMemo().
const Nnet &nnet_ [private]
Definition at line 134 of file nnet-compute.h.
Referenced by NnetComputer::CheckNoPendingIo(), NnetComputer::DebugAfterExecute(), NnetComputer::DebugBeforeExecute(), NnetComputer::ExecuteCommand(), NnetComputer::GetIoMatrixIndex(), NnetComputer::Init(), and NnetComputer::Run().
Nnet *nnet_to_store_stats_ [private]
Definition at line 146 of file nnet-compute.h.
Referenced by NnetComputer::ExecuteCommand().
Nnet *nnet_to_update_ [private]
Definition at line 150 of file nnet-compute.h.
Referenced by NnetComputer::ExecuteCommand().
const NnetComputeOptions &options_ [private]
Definition at line 132 of file nnet-compute.h.
Referenced by NnetComputer::Init().
std::vector<int32> pending_commands_ [private]
Definition at line 141 of file nnet-compute.h.
Referenced by NnetComputer::CheckNoPendingIo(), and NnetComputer::GetIoMatrixIndex().
int32 program_counter_ [private]
Definition at line 136 of file nnet-compute.h.
Referenced by NnetComputer::CheckNoPendingIo(), NnetComputer::ExecuteCommand(), NnetComputer::GetIoMatrixIndex(), and NnetComputer::Run().
std::vector<std::string> submatrix_strings_ [private]
Definition at line 155 of file nnet-compute.h.
Referenced by NnetComputer::DebugAfterExecute(), and NnetComputer::Init().