#include <nnet-computation.h>
Public Member Functions
ComputationRequest ()
bool | NeedDerivatives () const
Returns true if any of inputs[*].has_deriv is true, or need_model_derivative is true.
int32 | IndexForInput (const std::string &node_name) const
Returns the index into "inputs" corresponding to the node with name "node_name", or -1 if there is no such index.
int32 | IndexForOutput (const std::string &node_name) const
Returns the index into "outputs" corresponding to the node with name "node_name", or -1 if there is no such index.
void | Print (std::ostream &os) const
Prints information about the computation request in a human-readable form.
void | Read (std::istream &istream, bool binary)
void | Write (std::ostream &ostream, bool binary) const
bool | operator== (const ComputationRequest &other) const
Public Attributes
std::vector< IoSpecification > | inputs
std::vector< IoSpecification > | outputs
bool | need_model_derivative
If need_model_derivative is true, then we will be doing either model training or model-derivative computation, so updatable components need to be backpropagated.
bool | store_component_stats
Set store_component_stats to true if you need the average-activation and average-derivative statistics stored by the StoreStats() functions of components such as Tanh, Sigmoid and Softmax.
MiscComputationInfo | misc_info
misc_info is for extensibility to things that don't easily fit into the framework.
Definition at line 114 of file nnet-computation.h.
ComputationRequest () | inline
Definition at line 132 of file nnet-computation.h.
References MiscComputationInfo::operator==(), and MiscComputationInfo::Print().
int32 IndexForInput (const std::string &node_name) const
Returns the index into "inputs" corresponding to the node with name "node_name", or -1 if there is no such index.
It is an error if >1 inputs have the same name.
Definition at line 54 of file nnet-computation.cc.
References ComputationRequest::inputs, and KALDI_ASSERT.
Referenced by Compiler::ComputeDerivNeeded().
int32 IndexForOutput (const std::string &node_name) const
Returns the index into "outputs" corresponding to the node with name "node_name", or -1 if there is no such index.
It is an error if >1 outputs have the same name.
Definition at line 66 of file nnet-computation.cc.
References KALDI_ASSERT, and ComputationRequest::outputs.
Referenced by Compiler::ComputeDerivNeeded().
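IndexForInput() and IndexForOutput() perform the same name lookup over two different vectors. The sketch below illustrates that lookup under simplified assumptions: IoSpec is a hypothetical stand-in for IoSpecification that holds only a node name, IndexForName is a hypothetical free function, and assert() takes the place of KALDI_ASSERT for the duplicate-name check. It is not the Kaldi implementation.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for kaldi::nnet3::IoSpecification: only the node
// name is modelled here.
struct IoSpec {
  std::string name;
};

// Returns the position of the entry named `node_name` in `specs`, or -1 if
// there is none. Duplicate names are treated as an error, checked with
// assert() in place of KALDI_ASSERT.
int IndexForName(const std::vector<IoSpec> &specs,
                 const std::string &node_name) {
  int ans = -1;
  for (size_t i = 0; i < specs.size(); i++) {
    if (specs[i].name == node_name) {
      assert(ans == -1 && "more than one entry has the same name");
      ans = static_cast<int>(i);
    }
  }
  return ans;
}
```

Calling IndexForName(outputs, "output") plays the role of IndexForOutput("output"); a result of -1 means the request contains no node with that name.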
bool NeedDerivatives () const
Returns true if any of inputs[*].has_deriv is true, or need_model_derivative is true.
Definition at line 29 of file nnet-computation.cc.
References ComputationRequest::inputs, KALDI_ERR, ComputationRequest::need_model_derivative, and ComputationRequest::outputs.
Referenced by Compiler::SetUpPrecomputedIndexes().
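The brief description mentions only the inputs and need_model_derivative, but the References line also lists outputs and KALDI_ERR. A plausible reading, assumed in the sketch below rather than taken from the Kaldi source, is that once derivatives are needed the function also checks that at least one output has has_deriv set, and fails otherwise. IoSpec and Request are simplified, hypothetical stand-ins for the real classes, and std::runtime_error replaces KALDI_ERR.

```cpp
#include <cassert>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for IoSpecification: only has_deriv matters here.
struct IoSpec {
  bool has_deriv;
};

struct Request {
  std::vector<IoSpec> inputs;
  std::vector<IoSpec> outputs;
  bool need_model_derivative = false;

  // Sketch of NeedDerivatives(): true if model derivatives or any input
  // derivative is requested. Assumes that requesting derivatives while no
  // output supplies a derivative is an error (thrown as runtime_error).
  bool NeedDerivatives() const {
    bool ans = need_model_derivative;
    for (const IoSpec &in : inputs)
      if (in.has_deriv) ans = true;
    if (!ans) return false;
    for (const IoSpec &out : outputs)
      if (out.has_deriv) return true;
    throw std::runtime_error(
        "derivatives requested but no output has has_deriv set");
  }
};
```

In this sketch a pure forward (decoding) request returns false, while a training request must mark at least one output with has_deriv to avoid the error.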
bool operator== (const ComputationRequest &other) const
Definition at line 1125 of file nnet-computation.cc.
References ComputationRequest::inputs, ComputationRequest::misc_info, ComputationRequest::need_model_derivative, ComputationRequest::outputs, and ComputationRequest::store_component_stats.
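The References line for operator== names exactly five members: inputs, outputs, misc_info, need_model_derivative and store_component_stats. The sketch below shows member-wise equality over those fields using hypothetical stand-in types (IoSpec, MiscInfo, Request). The real IoSpecification also stores a vector of Index values, which its own comparison presumably covers; MiscInfo is assumed here to be an empty placeholder that always compares equal.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for IoSpecification (indexes omitted).
struct IoSpec {
  std::string name;
  bool has_deriv;
  bool operator==(const IoSpec &o) const {
    return name == o.name && has_deriv == o.has_deriv;
  }
};

// Hypothetical empty placeholder for MiscComputationInfo.
struct MiscInfo {
  bool operator==(const MiscInfo &) const { return true; }
};

struct Request {
  std::vector<IoSpec> inputs, outputs;
  bool need_model_derivative = false;
  bool store_component_stats = false;
  MiscInfo misc_info;

  // Two requests are equal only if all five members match.
  bool operator==(const Request &o) const {
    return inputs == o.inputs && outputs == o.outputs &&
           need_model_derivative == o.need_model_derivative &&
           store_component_stats == o.store_component_stats &&
           misc_info == o.misc_info;
  }
};
```

Because every flag participates, two requests that differ only in store_component_stats compare unequal in this sketch.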
void Print (std::ostream &os) const
Prints information about the computation request in a human-readable form.
Definition at line 1092 of file nnet-computation.cc.
References ComputationRequest::inputs, ComputationRequest::misc_info, ComputationRequest::need_model_derivative, ComputationRequest::outputs, MiscComputationInfo::Print(), and ComputationRequest::store_component_stats.
Referenced by kaldi::nnet3::CompileLoopedInternal(), CachingOptimizingCompiler::CompileNoShortcut(), kaldi::nnet3::UnitTestNnetCompile(), kaldi::nnet3::UnitTestNnetCompileLooped(), and kaldi::nnet3::UnitTestNnetCompileMulti().
void Read (std::istream &istream, bool binary)
Definition at line 1034 of file nnet-computation.cc.
References kaldi::nnet3::ExpectToken(), ComputationRequest::inputs, KALDI_ASSERT, ComputationRequest::need_model_derivative, ComputationRequest::outputs, kaldi::ReadBasicType(), and ComputationRequest::store_component_stats.
Referenced by ComputationCache::Read(), and kaldi::nnet3::UnitTestComputationRequestIo().
void Write (std::ostream &ostream, bool binary) const
Definition at line 1063 of file nnet-computation.cc.
References ComputationRequest::inputs, ComputationRequest::need_model_derivative, ComputationRequest::outputs, ComputationRequest::store_component_stats, kaldi::WriteBasicType(), and kaldi::WriteToken().
Referenced by kaldi::nnet3::UnitTestComputationRequestIo().
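Read() and Write() are both referenced by UnitTestComputationRequestIo(). The sketch below shows a write-then-read round trip of the kind such a test would perform, on a simplified, hypothetical Request whose inputs and outputs are reduced to node names. The token names and layout are invented for illustration and do not reproduce Kaldi's on-disk format; only text mode is modelled, and the Expect helper plays the role that ExpectToken() plays in Kaldi.

```cpp
#include <cassert>
#include <istream>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical simplified request: I/O entries reduced to node names.
struct Request {
  std::vector<std::string> inputs, outputs;
  bool need_model_derivative = false;
  bool store_component_stats = false;

  bool operator==(const Request &o) const {
    return inputs == o.inputs && outputs == o.outputs &&
           need_model_derivative == o.need_model_derivative &&
           store_component_stats == o.store_component_stats;
  }

  // Text-mode writer using invented tokens (not Kaldi's actual format).
  void Write(std::ostream &os) const {
    os << "<NumInputs> " << inputs.size();
    for (const std::string &n : inputs) os << ' ' << n;
    os << " <NumOutputs> " << outputs.size();
    for (const std::string &n : outputs) os << ' ' << n;
    os << " <NeedModelDerivative> " << need_model_derivative
       << " <StoreComponentStats> " << store_component_stats << '\n';
  }

  // Matching reader: checks each token before reading its value.
  void Read(std::istream &is) {
    size_t n = 0;
    Expect(is, "<NumInputs>");
    is >> n;
    inputs.resize(n);
    for (std::string &s : inputs) is >> s;
    Expect(is, "<NumOutputs>");
    is >> n;
    outputs.resize(n);
    for (std::string &s : outputs) is >> s;
    Expect(is, "<NeedModelDerivative>");
    is >> need_model_derivative;
    Expect(is, "<StoreComponentStats>");
    is >> store_component_stats;
    if (!is) throw std::runtime_error("read failed");
  }

 private:
  static void Expect(std::istream &is, const std::string &token) {
    std::string t;
    is >> t;
    if (t != token) throw std::runtime_error("expected " + token);
  }
};
```

Comparing the re-read object with operator== is a convenient way to check that Read() and Write() stay in sync.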
std::vector<IoSpecification> inputs
Definition at line 115 of file nnet-computation.h.
Referenced by kaldi::nnet3::computation_graph::AddInputToGraph(), kaldi::nnet3::AddTimeOffsetToComputationRequest(), DecodableNnetLoopedOnlineBase::AdvanceChunk(), DecodableNnetSimpleLooped::AdvanceChunk(), BatchedXvectorComputer::BatchedXvectorComputer(), Compiler::ComputeDerivNeeded(), kaldi::nnet3::ComputeExampleComputationRequestSimple(), kaldi::nnet3::ComputeSimpleNnetContextForShift(), kaldi::nnet3::CreateComputationRequestInternal(), Compiler::DeallocateMatrices(), DecodableNnetSimple::DoNnetComputation(), kaldi::nnet3::ExtrapolateComputationRequest(), kaldi::nnet3::GetChainComputationRequest(), kaldi::nnet3::GetComputationRequest(), NnetBatchComputer::GetComputationRequest(), kaldi::nnet3::GetDiscriminativeComputationRequest(), ComputationRequest::IndexForInput(), ComputationRequest::NeedDerivatives(), ComputationRequest::operator==(), ComputationRequest::Print(), ComputationStepsComputer::ProcessInputOrOutputStep(), ComputationRequest::Read(), kaldi::nnet3::RequestIsDecomposable(), kaldi::nnet3::RunNnetComputation(), kaldi::nnet3::SetDerivTimesOptions(), kaldi::nnet3::UnitTestNnetCompileMulti(), kaldi::nnet3::UnitTestNnetCompute(), kaldi::nnet3::UnitTestNnetInputDerivatives(), kaldi::nnet3::UnitTestNnetModelDerivatives(), kaldi::nnet3::UnitTestNnetOptimizeWithOptions(), and ComputationRequest::Write().
MiscComputationInfo misc_info
misc_info is for extensibility to things that don't easily fit into the framework.
Definition at line 130 of file nnet-computation.h.
Referenced by ComputationGraphBuilder::AddDependencies(), CachingOptimizingCompiler::CompileViaShortcut(), ComputationGraphBuilder::ComputeComputableInfo(), kaldi::nnet3::ComputeComputationGraph(), ComputationRequest::operator==(), ComputationRequest::Print(), kaldi::nnet3::RequestIsDecomposable(), and Compiler::SetUpPrecomputedIndexes().
bool need_model_derivative
If need_model_derivative is true, then we will be doing either model training or model-derivative computation, so updatable components need to be backpropagated.
Definition at line 121 of file nnet-computation.h.
Referenced by BatchedXvectorComputer::BatchedXvectorComputer(), Compiler::ComputeDerivNeeded(), kaldi::nnet3::ComputeExampleComputationRequestSimple(), DecodableNnetSimple::DoNnetComputation(), kaldi::nnet3::GetChainComputationRequest(), kaldi::nnet3::GetComputationRequest(), NnetBatchComputer::GetComputationRequest(), kaldi::nnet3::GetDiscriminativeComputationRequest(), ComputationRequest::NeedDerivatives(), ComputationRequest::operator==(), ComputationRequest::Print(), ComputationRequest::Read(), NnetComputation::Read(), kaldi::nnet3::RequestIsDecomposable(), kaldi::nnet3::RunNnetComputation(), kaldi::nnet3::UnitTestNnetCompileMulti(), kaldi::nnet3::UnitTestNnetInputDerivatives(), kaldi::nnet3::UnitTestNnetModelDerivatives(), ComputationRequest::Write(), and NnetComputation::Write().
std::vector<IoSpecification> outputs
Definition at line 116 of file nnet-computation.h.
Referenced by kaldi::nnet3::computation_graph::AddOutputToGraph(), kaldi::nnet3::AddTimeOffsetToComputationRequest(), BatchedXvectorComputer::BatchedXvectorComputer(), Compiler::ComputeDerivNeeded(), kaldi::nnet3::ComputeExampleComputationRequestSimple(), kaldi::nnet3::ComputeSimpleNnetContextForShift(), kaldi::nnet3::CreateComputationRequestInternal(), DecodableNnetSimple::DoNnetComputation(), kaldi::nnet3::GetChainComputationRequest(), ComputationGraphBuilder::GetComputableInfo(), kaldi::nnet3::GetComputationRequest(), NnetBatchComputer::GetComputationRequest(), kaldi::nnet3::GetDiscriminativeComputationRequest(), ComputationRequest::IndexForOutput(), kaldi::nnet3::MaxOutputTimeInRequest(), ComputationRequest::NeedDerivatives(), ComputationRequest::operator==(), ComputationRequest::Print(), ComputationStepsComputer::ProcessInputOrOutputStep(), ComputationRequest::Read(), kaldi::nnet3::RequestIsDecomposable(), kaldi::nnet3::RunNnetComputation(), kaldi::nnet3::SetDerivTimesOptions(), kaldi::nnet3::UnitTestNnetCompileMulti(), kaldi::nnet3::UnitTestNnetCompute(), kaldi::nnet3::UnitTestNnetInputDerivatives(), kaldi::nnet3::UnitTestNnetModelDerivatives(), kaldi::nnet3::UnitTestNnetOptimizeWithOptions(), and ComputationRequest::Write().
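Each IoSpecification in inputs and outputs names a network node and lists the indexes at which its value is supplied (inputs) or requested (outputs). The sketch below fills both vectors for a single sequence (n = 0), requesting outputs for frames 0 to num_frames - 1 and widening the input by left and right context. Index, IoSpec and MakeSimpleRequest are simplified, hypothetical stand-ins; the context values are illustrative, and "input"/"output" are the conventional nnet3 node names, assumed here.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Stand-in for kaldi::nnet3::Index: n = sequence within the minibatch,
// t = frame time, x = extra index (usually 0).
struct Index {
  int n, t, x;
};

// Hypothetical stand-in for IoSpecification.
struct IoSpec {
  std::string name;
  std::vector<Index> indexes;
  bool has_deriv = false;  // left false: a forward-only request
};

// Hypothetical helper: outputs for frames [0, num_frames), inputs for
// frames [-left_context, num_frames + right_context).
void MakeSimpleRequest(int num_frames, int left_context, int right_context,
                       std::vector<IoSpec> *inputs,
                       std::vector<IoSpec> *outputs) {
  IoSpec in, out;
  in.name = "input";
  out.name = "output";
  for (int t = -left_context; t < num_frames + right_context; t++)
    in.indexes.push_back(Index{0, t, 0});
  for (int t = 0; t < num_frames; t++)
    out.indexes.push_back(Index{0, t, 0});
  inputs->assign(1, in);
  outputs->assign(1, out);
}
```

With num_frames = 8, left_context = 2 and right_context = 3, the input spans t = -2 to 10 (13 indexes) while the output spans t = 0 to 7.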
bool store_component_stats
Set store_component_stats to true if you need the average-activation and average-derivative statistics stored by the StoreStats() functions of components such as Tanh, Sigmoid and Softmax.
Definition at line 126 of file nnet-computation.h.
Referenced by BatchedXvectorComputer::BatchedXvectorComputer(), kaldi::nnet3::ComputeExampleComputationRequestSimple(), DecodableNnetSimple::DoNnetComputation(), kaldi::nnet3::GetChainComputationRequest(), kaldi::nnet3::GetComputationRequest(), NnetBatchComputer::GetComputationRequest(), kaldi::nnet3::GetDiscriminativeComputationRequest(), ComputationRequest::operator==(), ComputationRequest::Print(), ComputationRequest::Read(), kaldi::nnet3::RequestIsDecomposable(), kaldi::nnet3::RunNnetComputation(), kaldi::nnet3::UnitTestNnetCompileMulti(), and ComputationRequest::Write().