ComputationRequest Struct Reference

#include <nnet-computation.h>


Public Member Functions

 ComputationRequest ()
 
bool NeedDerivatives () const
 Returns true if any of inputs[*].has_deriv is true, or if need_model_derivative is true.
 
int32 IndexForInput (const std::string &node_name) const
 Returns the index into "inputs" corresponding to the node with name "node_name", or -1 if there is no such index.
 
int32 IndexForOutput (const std::string &node_name) const
 Returns the index into "outputs" corresponding to the node with name "node_name", or -1 if there is no such index.
 
void Print (std::ostream &os) const
 This function is for printing info about the computation request in a human-readable way.
 
void Read (std::istream &istream, bool binary)
 
void Write (std::ostream &ostream, bool binary) const
 
bool operator== (const ComputationRequest &other) const
 

Public Attributes

std::vector< IoSpecification > inputs
 
std::vector< IoSpecification > outputs
 
bool need_model_derivative
 If need_model_derivative is true, then we'll be doing either model training or model-derivative computation, so updatable components need to be backpropagated.
 
bool store_component_stats
 You should set store_component_stats to true if you need the average-activation and average-derivative statistics stored by the StoreStats() functions of components such as Tanh, Sigmoid and Softmax.
 
MiscComputationInfo misc_info
 misc_info is for extensibility to things that don't easily fit into the framework.
 

Detailed Description

Definition at line 114 of file nnet-computation.h.

Constructor & Destructor Documentation

ComputationRequest ( )
inline

Definition at line 132 of file nnet-computation.h.

: need_model_derivative(false),
    store_component_stats(false) { }
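The constructor sets only the two flags; inputs and outputs start empty. As a minimal sketch of how a caller might fill in a request for a training step, the code below uses hypothetical stand-ins in a namespace `sketch` (not Kaldi's types): `IoSpecification` is cut down to the `name` and `has_deriv` fields used on this page, and the node names "input" and "output" and the helper `MakeTrainingRequest()` are assumptions. It sets `need_model_derivative`, as the attribute description calls for in training, and requests a derivative at the output, which `NeedDerivatives()` requires whenever any derivative is needed.

```cpp
#include <cassert>
#include <string>
#include <vector>

namespace sketch {

// Hypothetical, reduced stand-in for IoSpecification: only the two fields
// used on this page (the real struct carries more, e.g. the indexes).
struct IoSpecification {
  std::string name;
  bool has_deriv = false;
};

// Hypothetical, reduced stand-in for ComputationRequest (misc_info omitted).
struct ComputationRequest {
  std::vector<IoSpecification> inputs;
  std::vector<IoSpecification> outputs;
  bool need_model_derivative;
  bool store_component_stats;
  // Mirrors the documented constructor: both flags start out false and the
  // input/output lists start out empty.
  ComputationRequest()
      : need_model_derivative(false), store_component_stats(false) {}
};

// Hypothetical helper: builds a request for one training step. No derivative
// is wanted w.r.t. the input, the objective derivative is supplied at the
// output, and the model derivative is requested.
inline ComputationRequest MakeTrainingRequest() {
  ComputationRequest req;
  req.inputs.push_back({"input", false});   // node name "input" is assumed
  req.outputs.push_back({"output", true});  // node name "output" is assumed
  req.need_model_derivative = true;
  return req;
}

}  // namespace sketch
```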

Member Function Documentation

int32 IndexForInput ( const std::string &  node_name) const

Returns the index into "inputs" corresponding to the node with name "node_name", or -1 if there is no such index.

It is an error if >1 inputs have the same name.

Definition at line 54 of file nnet-computation.cc.

References ComputationRequest::inputs and KALDI_ASSERT.

Referenced by Compiler::ComputeDerivNeeded().

{
  int32 ans = -1;
  for (size_t i = 0; i < inputs.size(); i++) {
    if (inputs[i].name == node_name) {
      KALDI_ASSERT(ans == -1 && "Two inputs with the same name");
      ans = i;
    }
  }
  return ans;
}
int32 IndexForOutput ( const std::string &  node_name) const

Returns the index into "outputs" corresponding to the node with name "node_name", or -1 if there is no such index.

It is an error if >1 outputs have the same name.

Definition at line 66 of file nnet-computation.cc.

References KALDI_ASSERT and ComputationRequest::outputs.

Referenced by Compiler::ComputeDerivNeeded().

{
  int32 ans = -1;
  for (size_t i = 0; i < outputs.size(); i++) {
    if (outputs[i].name == node_name) {
      KALDI_ASSERT(ans == -1 && "Two outputs with the same name");
      ans = i;
    }
  }
  return ans;
}
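IndexForInput() and IndexForOutput() run the same linear search over different vectors, so a message copied from one into the other can easily name the wrong list. A minimal sketch of one shared routine follows, assuming a hypothetical stand-in `IoSpecification` and a hypothetical helper `IndexForName()` that takes the list plus a label for error messages; `std::logic_error` stands in for `KALDI_ASSERT`, and `int32_t` stands in for Kaldi's `int32`.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

namespace sketch {

// Hypothetical stand-in for IoSpecification; only `name` matters here.
struct IoSpecification {
  std::string name;
  bool has_deriv = false;
};

// Linear search shared by the input and output lookups. Returns the position
// of the entry named `node_name`, or -1 if there is none. A duplicate name is
// reported with `what` ("inputs" or "outputs") so the message always names
// the list that was actually searched.
inline int32_t IndexForName(const std::vector<IoSpecification> &specs,
                            const std::string &node_name,
                            const char *what) {
  int32_t ans = -1;
  for (size_t i = 0; i < specs.size(); i++) {
    if (specs[i].name == node_name) {
      // Stands in for KALDI_ASSERT; throwing keeps the sketch testable.
      if (ans != -1)
        throw std::logic_error(std::string("Two ") + what +
                               " with the same name: " + node_name);
      ans = static_cast<int32_t>(i);
    }
  }
  return ans;
}

}  // namespace sketch
```

With such a helper, the two member functions would reduce to `IndexForName(inputs, node_name, "inputs")` and `IndexForName(outputs, node_name, "outputs")`.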
bool NeedDerivatives ( ) const

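Returns true if any of inputs[*].has_deriv is true, or if need_model_derivative is true. In that case at least one of outputs[*].has_deriv must also be true; otherwise the function fails with KALDI_ERR.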

Definition at line 29 of file nnet-computation.cc.

References ComputationRequest::inputs, KALDI_ERR, ComputationRequest::need_model_derivative, and ComputationRequest::outputs.

Referenced by Compiler::SetUpPrecomputedIndexes().

{
  bool ans = false;
  if (need_model_derivative)
    ans = true;
  for (size_t i = 0; i < inputs.size(); i++) {
    if (inputs[i].has_deriv) {  // derivative requested for this input
      ans = true;
      break;
    }
  }
  if (ans) {
    // check that at least one output actually provides a derivative, else
    // the request could not be meaningfully satisfied.
    size_t i;
    for (i = 0; i < outputs.size(); i++)
      if (outputs[i].has_deriv)
        break;
    if (i == outputs.size()) {
      KALDI_ERR << "You requested model derivatives or input derivatives, but "
                << "provide no derivatives at the output.";
    }
  }
  return ans;
}
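The decision made by NeedDerivatives() can be exercised outside Kaldi. The sketch below restates it as a hypothetical free function over a stand-in `IoSpecification` (assumed fields `name` and `has_deriv` only), with `std::runtime_error` standing in for `KALDI_ERR`; it is an illustration of the logic, not the library's code.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

namespace sketch {

// Hypothetical stand-in for IoSpecification.
struct IoSpecification {
  std::string name;
  bool has_deriv = false;
};

// Same decision as NeedDerivatives(), written as a free function over the
// fields it reads. std::runtime_error stands in for KALDI_ERR.
inline bool NeedDerivatives(const std::vector<IoSpecification> &inputs,
                            const std::vector<IoSpecification> &outputs,
                            bool need_model_derivative) {
  bool ans = need_model_derivative;
  for (const IoSpecification &in : inputs) {
    if (in.has_deriv) {  // an input derivative was requested
      ans = true;
      break;
    }
  }
  if (ans) {
    // Derivatives can only flow backward if some output supplies one.
    bool any_output_deriv = false;
    for (const IoSpecification &out : outputs)
      if (out.has_deriv) any_output_deriv = true;
    if (!any_output_deriv)
      throw std::runtime_error(
          "You requested model derivatives or input derivatives, but "
          "provide no derivatives at the output.");
  }
  return ans;
}

}  // namespace sketch
```

Plain inference (no derivatives anywhere) returns false; training (model derivative plus an output derivative) returns true; asking for derivatives without supplying one at any output is rejected.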
bool operator== ( const ComputationRequest &  other) const

Definition at line 1084 of file nnet-computation.cc.

References ComputationRequest::inputs, ComputationRequest::misc_info, ComputationRequest::need_model_derivative, ComputationRequest::outputs, and ComputationRequest::store_component_stats.

{
  // rely on the std::vector's default implementation of ==, which in turn
  // relies on the == operator of class IoSpecification.
  return inputs == other.inputs && outputs == other.outputs &&
         need_model_derivative == other.need_model_derivative &&
         store_component_stats == other.store_component_stats &&
         misc_info == other.misc_info;
}
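operator== leans on std::vector's element-wise comparison, so IoSpecification has to provide its own operator==. The sketch below shows that dependency with hypothetical stand-in types (misc_info omitted) and writes the memberwise comparison with std::tie, one common idiom that keeps the compared members in a single list; the Kaldi code above uses a chain of && instead.

```cpp
#include <cassert>
#include <string>
#include <tuple>
#include <vector>

namespace sketch {

// Hypothetical stand-in for IoSpecification.
struct IoSpecification {
  std::string name;
  bool has_deriv = false;
  // std::vector's operator== compares element by element, so the element
  // type must itself provide operator==.
  bool operator==(const IoSpecification &other) const {
    return name == other.name && has_deriv == other.has_deriv;
  }
};

// Hypothetical stand-in for ComputationRequest; misc_info omitted.
struct ComputationRequest {
  std::vector<IoSpecification> inputs, outputs;
  bool need_model_derivative = false;
  bool store_component_stats = false;

  // Memberwise equality: std::tie builds tuples of references, and tuple
  // equality compares the members pairwise, in order.
  bool operator==(const ComputationRequest &other) const {
    return std::tie(inputs, outputs, need_model_derivative,
                    store_component_stats) ==
           std::tie(other.inputs, other.outputs, other.need_model_derivative,
                    other.store_component_stats);
  }
};

}  // namespace sketch
```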
void Print ( std::ostream &  os) const

This function is for printing info about the computation request in a human-readable way.

Definition at line 1051 of file nnet-computation.cc.

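References ComputationRequest::inputs, ComputationRequest::misc_info, ComputationRequest::need_model_derivative, ComputationRequest::outputs, and ComputationRequest::store_component_stats.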

Referenced by kaldi::nnet3::CompileLoopedInternal(), CachingOptimizingCompiler::CompileNoShortcut(), ComputationGraphBuilder::ExplainWhyAllOutputsNotComputable(), kaldi::nnet3::UnitTestNnetCompile(), kaldi::nnet3::UnitTestNnetCompileLooped(), and kaldi::nnet3::UnitTestNnetCompileMulti().

{
  os << " # Computation request:\n";
  for (size_t i = 0; i < inputs.size(); i++) {
    os << "input-" << i << ": ";
    inputs[i].Print(os);
  }
  for (size_t i = 0; i < outputs.size(); i++) {
    os << "output-" << i << ": ";
    outputs[i].Print(os);
  }
  os << "need-model-derivative: " <<
      (need_model_derivative ? "true\n" : "false\n");
  os << "store-component-stats: " <<
      (store_component_stats ? "true\n" : "false\n");
  misc_info.Print(os);
}
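Because Print() writes to any std::ostream, its output can be captured in a string, for example to include in a log or error message. The sketch below copies the layout of Print() onto hypothetical stand-in types (misc_info omitted; the one-line format of the stand-in `IoSpecification::Print()` is invented for the example) and captures it with std::ostringstream.

```cpp
#include <cassert>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

namespace sketch {

// Hypothetical stand-in; its Print() format is made up for this sketch.
struct IoSpecification {
  std::string name;
  bool has_deriv = false;
  void Print(std::ostream &os) const {
    os << "name=" << name << ", has-deriv=" << (has_deriv ? "true" : "false")
       << "\n";
  }
};

// Hypothetical stand-in for ComputationRequest; misc_info omitted.
struct ComputationRequest {
  std::vector<IoSpecification> inputs, outputs;
  bool need_model_derivative = false;
  bool store_component_stats = false;

  // Same layout as ComputationRequest::Print(): a header line, one numbered
  // line per input and output, then the two flags.
  void Print(std::ostream &os) const {
    os << " # Computation request:\n";
    for (size_t i = 0; i < inputs.size(); i++) {
      os << "input-" << i << ": ";
      inputs[i].Print(os);
    }
    for (size_t i = 0; i < outputs.size(); i++) {
      os << "output-" << i << ": ";
      outputs[i].Print(os);
    }
    os << "need-model-derivative: "
       << (need_model_derivative ? "true\n" : "false\n");
    os << "store-component-stats: "
       << (store_component_stats ? "true\n" : "false\n");
  }
};

}  // namespace sketch
```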
void Read ( std::istream &  istream,
bool  binary 
)

Definition at line 993 of file nnet-computation.cc.

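References kaldi::nnet3::ExpectToken(), ComputationRequest::inputs, KALDI_ASSERT, ComputationRequest::need_model_derivative, ComputationRequest::outputs, kaldi::ReadBasicType(), and ComputationRequest::store_component_stats.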

Referenced by CachingOptimizingCompiler::ReadCache(), and kaldi::nnet3::UnitTestComputationRequestIo().

{
  ExpectToken(is, binary, "<ComputationRequest>");
  size_t num_inputs;
  ExpectToken(is, binary, "<NumInputs>");
  ReadBasicType(is, binary, &num_inputs);
  KALDI_ASSERT(num_inputs >= 0);
  inputs.resize(num_inputs);
  ExpectToken(is, binary, "<Inputs>");
  for (size_t c = 0; c < num_inputs; c++) {
    inputs[c].Read(is, binary);
  }

  size_t num_outputs;
  ExpectToken(is, binary, "<NumOutputs>");
  ReadBasicType(is, binary, &num_outputs);
  KALDI_ASSERT(num_outputs >= 0);
  outputs.resize(num_outputs);
  ExpectToken(is, binary, "<Outputs>");
  for (size_t c = 0; c < num_outputs; c++) {
    outputs[c].Read(is, binary);
  }

  ExpectToken(is, binary, "<NeedModelDerivative>");
  ReadBasicType(is, binary, &need_model_derivative);
  ExpectToken(is, binary, "<StoreComponentStats>");
  ReadBasicType(is, binary, &store_component_stats);
  ExpectToken(is, binary, "</ComputationRequest>");
}
void Write ( std::ostream &  ostream,
bool  binary 
) const

Definition at line 1022 of file nnet-computation.cc.

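References ComputationRequest::inputs, ComputationRequest::need_model_derivative, ComputationRequest::outputs, ComputationRequest::store_component_stats, kaldi::WriteBasicType(), and kaldi::WriteToken().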

Referenced by kaldi::nnet3::UnitTestComputationRequestIo().

{
  WriteToken(os, binary, "<ComputationRequest>");
  if (!binary) os << std::endl;
  WriteToken(os, binary, "<NumInputs>");
  WriteBasicType(os, binary, inputs.size());
  if (!binary) os << std::endl;
  WriteToken(os, binary, "<Inputs>");
  for (size_t c = 0; c < inputs.size(); c++) {
    inputs[c].Write(os, binary);
  }
  if (!binary) os << std::endl;

  WriteToken(os, binary, "<NumOutputs>");
  WriteBasicType(os, binary, outputs.size());
  if (!binary) os << std::endl;
  WriteToken(os, binary, "<Outputs>");
  for (size_t c = 0; c < outputs.size(); c++) {
    outputs[c].Write(os, binary);
  }
  if (!binary) os << std::endl;

  WriteToken(os, binary, "<NeedModelDerivative>");
  WriteBasicType(os, binary, need_model_derivative);
  WriteToken(os, binary, "<StoreComponentStats>");
  WriteBasicType(os, binary, store_component_stats);
  WriteToken(os, binary, "</ComputationRequest>");
  if (!binary) os << std::endl;
}
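Read() expects each flag token to be followed by its value, so a writer must emit the value right after <NeedModelDerivative> and after <StoreComponentStats>. The text-mode-only sketch below reproduces just that part of the format. `WriteToken`, `ExpectToken`, `ReadBool` and `FlagsOnlyRequest` here are hypothetical simplified stand-ins, not Kaldi's io-funcs, and the "T"/"F" encoding of booleans is an assumption of the sketch. It shows a round trip succeeding and a stream that omits a value being rejected.

```cpp
#include <cassert>
#include <istream>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>

namespace sketch {

// Hypothetical text-mode helpers; not Kaldi's WriteToken/ExpectToken.
inline void WriteToken(std::ostream &os, const std::string &tok) {
  os << tok << ' ';
}

inline void ExpectToken(std::istream &is, const std::string &tok) {
  std::string got;
  is >> got;
  if (got != tok)
    throw std::runtime_error("Expected token " + tok + ", got '" + got + "'");
}

// Assumed boolean encoding for this sketch: "T" or "F".
inline bool ReadBool(std::istream &is) {
  std::string v;
  is >> v;
  if (v == "T") return true;
  if (v == "F") return false;
  throw std::runtime_error("Expected T or F, got '" + v + "'");
}

// Reduced request holding only the two flags, to keep the sketch short.
struct FlagsOnlyRequest {
  bool need_model_derivative = false;
  bool store_component_stats = false;

  void Write(std::ostream &os) const {
    WriteToken(os, "<ComputationRequest>");
    WriteToken(os, "<NeedModelDerivative>");
    os << (need_model_derivative ? "T" : "F") << ' ';  // value follows token
    WriteToken(os, "<StoreComponentStats>");
    os << (store_component_stats ? "T" : "F") << ' ';
    WriteToken(os, "</ComputationRequest>");
  }

  // Reads tokens and values in exactly the order Write() produced them.
  void Read(std::istream &is) {
    ExpectToken(is, "<ComputationRequest>");
    ExpectToken(is, "<NeedModelDerivative>");
    need_model_derivative = ReadBool(is);
    ExpectToken(is, "<StoreComponentStats>");
    store_component_stats = ReadBool(is);
    ExpectToken(is, "</ComputationRequest>");
  }
};

}  // namespace sketch
```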

Member Data Documentation

bool store_component_stats

You should set store_component_stats to true if you need the average-activation and average-derivative statistics stored by the StoreStats() functions of components such as Tanh, Sigmoid and Softmax.

Definition at line 126 of file nnet-computation.h.

Referenced by kaldi::nnet3::ComputeExampleComputationRequestSimple(), DecodableNnetSimple::DoNnetComputation(), kaldi::nnet3::GetChainComputationRequest(), kaldi::nnet3::GetComputationRequest(), kaldi::nnet3::GetDiscriminativeComputationRequest(), ComputationRequest::operator==(), kaldi::nnet3::RequestIsDecomposable(), kaldi::nnet3::RunNnetComputation(), and kaldi::nnet3::UnitTestNnetCompileMulti().


The documentation for this struct was generated from nnet-computation.h and nnet-computation.cc.