ComputationRequest Struct Reference

#include <nnet-computation.h>


Public Member Functions

 ComputationRequest ()
 
bool NeedDerivatives () const
 Returns true if any of inputs[*].has_deriv is true, or need_model_derivative is true.
 
int32 IndexForInput (const std::string &node_name) const
 Returns the index into "inputs" corresponding to the node with name "node_name", or -1 if there is no such index.
 
int32 IndexForOutput (const std::string &node_name) const
 Returns the index into "outputs" corresponding to the node with name "node_name", or -1 if there is no such index.
 
void Print (std::ostream &os) const
 This function is for printing info about the computation request in a human-readable way.
 
void Read (std::istream &istream, bool binary)
 
void Write (std::ostream &ostream, bool binary) const
 
bool operator== (const ComputationRequest &other) const
 

Public Attributes

std::vector< IoSpecification > inputs
 
std::vector< IoSpecification > outputs
 
bool need_model_derivative
 If need_model_derivative is true, then we'll be doing either model training or model-derivative computation, so updatable components need to be backprop'd.
 
bool store_component_stats
 You should set store_component_stats to true if you need the average-activation and average-derivative statistics stored by the StoreStats() functions of components such as Tanh, Sigmoid and Softmax.
 
MiscComputationInfo misc_info
 misc_info is for extensibility to things that don't easily fit into the framework.
 

Detailed Description

Definition at line 114 of file nnet-computation.h.

Constructor & Destructor Documentation

ComputationRequest ( )
inline

Definition at line 132 of file nnet-computation.h.

  : need_model_derivative(false),
    store_component_stats(false) { }
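As an illustration of how these fields fit together, the following minimal sketch builds a request with one input and one output and asks for model derivatives. It is not Kaldi code: `IoSpec`, `Request` and `MakeTrainingRequest` are hypothetical, simplified stand-ins, and the real IoSpecification stores `Index` objects rather than plain `int` frame numbers. The default constructor mirrors the one above, with both flags starting out false.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for IoSpecification; the real struct stores
// Index objects, not plain ints.
struct IoSpec {
  std::string name;
  std::vector<int> indexes;
  bool has_deriv;
};

// Hypothetical stand-in for ComputationRequest, with the same defaults
// as the constructor documented above.
struct Request {
  std::vector<IoSpec> inputs;
  std::vector<IoSpec> outputs;
  bool need_model_derivative;
  bool store_component_stats;
  Request() : need_model_derivative(false), store_component_stats(false) {}
};

// Builds a request that feeds frames 0..num_frames-1 into "input" and
// supplies a derivative at "output", as one might for training.
Request MakeTrainingRequest(int num_frames) {
  Request req;
  IoSpec in{"input", {}, false};   // no derivative w.r.t. the input
  IoSpec out{"output", {}, true};  // derivative provided at the output
  for (int t = 0; t < num_frames; t++) {
    in.indexes.push_back(t);
    out.indexes.push_back(t);
  }
  req.inputs.push_back(in);
  req.outputs.push_back(out);
  req.need_model_derivative = true;  // backprop into updatable components
  return req;
}
```

Setting has_deriv on the output rather than the input is what keeps such a request consistent with NeedDerivatives(), which requires at least one output derivative whenever need_model_derivative is true.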

Member Function Documentation

int32 IndexForInput ( const std::string &  node_name) const

Returns the index into "inputs" corresponding to the node with name "node_name", or -1 if there is no such index.

It is an error if >1 inputs have the same name.

Definition at line 53 of file nnet-computation.cc.

References ComputationRequest::inputs and KALDI_ASSERT.

Referenced by Compiler::ComputeDerivNeeded().

{
  int32 ans = -1;
  for (size_t i = 0; i < inputs.size(); i++) {
    if (inputs[i].name == node_name) {
      KALDI_ASSERT(ans == -1 && "Two inputs with the same name");
      ans = i;
    }
  }
  return ans;
}
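The lookup above can be sketched in isolation as follows. `IoSpec` and `IndexForName` are hypothetical names; a single helper that takes the vector as a parameter covers both IndexForInput() and IndexForOutput(), and a plain `assert` stands in for KALDI_ASSERT.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for IoSpecification: only the name matters here.
struct IoSpec {
  std::string name;
  bool has_deriv;
};

// Returns the position of the spec called node_name, or -1 if none exists.
// Two specs with the same name is treated as a programming error.
int IndexForName(const std::vector<IoSpec> &specs,
                 const std::string &node_name) {
  int ans = -1;
  for (size_t i = 0; i < specs.size(); i++) {
    if (specs[i].name == node_name) {
      assert(ans == -1 && "Two specs with the same name");
      ans = static_cast<int>(i);
    }
  }
  return ans;
}
```

The loop deliberately keeps scanning after a match, so a duplicated name is caught rather than silently resolved to its first occurrence.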
int32 IndexForOutput ( const std::string &  node_name) const

Returns the index into "outputs" corresponding to the node with name "node_name", or -1 if there is no such index.

It is an error if >1 outputs have the same name.

Definition at line 65 of file nnet-computation.cc.

References KALDI_ASSERT and ComputationRequest::outputs.

Referenced by Compiler::ComputeDerivNeeded().

{
  int32 ans = -1;
  for (size_t i = 0; i < outputs.size(); i++) {
    if (outputs[i].name == node_name) {
      KALDI_ASSERT(ans == -1 && "Two outputs with the same name");
      ans = i;
    }
  }
  return ans;
}
bool NeedDerivatives ( ) const

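Returns true if any of inputs[*].has_deriv is true, or need_model_derivative is true. In that case at least one output must have has_deriv set; otherwise the function raises an error with KALDI_ERR.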

Definition at line 28 of file nnet-computation.cc.

References ComputationRequest::inputs, KALDI_ERR, ComputationRequest::need_model_derivative, and ComputationRequest::outputs.

Referenced by Compiler::SetUpPrecomputedIndexes().

{
  bool ans = false;
  if (need_model_derivative)
    ans = true;
  for (size_t i = 0; i < inputs.size(); i++) {
    if (inputs[i].has_deriv) {  // derivative requested for this input
      ans = true;
      break;
    }
  }
  if (ans) {
    // check that the output actually provides a derivative, else the
    // request could not be meaningfully satisfied.
    size_t i;
    for (i = 0; i < outputs.size(); i++)
      if (outputs[i].has_deriv)
        break;
    if (i == outputs.size()) {
      KALDI_ERR << "You requested model derivatives or input derivatives, but "
                << "provide no derivatives at the output.";
    }
  }
  return ans;
}
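The control flow of NeedDerivatives() can be restated as a small standalone sketch. `IoSpec`, `Request` and the free function `NeedDerivatives` are hypothetical stand-ins, and throwing std::runtime_error is an assumption used here in place of KALDI_ERR.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

struct IoSpec {  // hypothetical stand-in for IoSpecification
  std::string name;
  bool has_deriv;
};

struct Request {  // hypothetical stand-in for ComputationRequest
  std::vector<IoSpec> inputs, outputs;
  bool need_model_derivative;
};

// Derivatives are needed if the model derivative or any input derivative is
// requested.  In that case at least one output must supply a derivative,
// otherwise backprop has nothing to start from.
bool NeedDerivatives(const Request &req) {
  bool ans = req.need_model_derivative;
  for (const IoSpec &in : req.inputs)
    if (in.has_deriv) ans = true;
  if (ans) {
    bool output_has_deriv = false;
    for (const IoSpec &out : req.outputs)
      if (out.has_deriv) output_has_deriv = true;
    if (!output_has_deriv)  // stands in for KALDI_ERR
      throw std::runtime_error(
          "derivatives requested but no output derivative provided");
  }
  return ans;
}
```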
bool operator== ( const ComputationRequest & other) const

Definition at line 1088 of file nnet-computation.cc.

References ComputationRequest::inputs, ComputationRequest::misc_info, ComputationRequest::need_model_derivative, ComputationRequest::outputs, and ComputationRequest::store_component_stats.

{
  // rely on the std::vector's default implementation of ==, which in turn
  // relies on the == operator of class IoSpecification.
  return inputs == other.inputs && outputs == other.outputs &&
      need_model_derivative == other.need_model_derivative &&
      store_component_stats == other.store_component_stats &&
      misc_info == other.misc_info;
}
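A minimal sketch of the same memberwise comparison, using std::tie to compare all fields in one expression. `IoSpec` and `Request` are hypothetical stand-ins, and misc_info is omitted.

```cpp
#include <cassert>
#include <string>
#include <tuple>
#include <vector>

struct IoSpec {  // hypothetical stand-in for IoSpecification
  std::string name;
  std::vector<int> indexes;
  bool has_deriv;
  bool operator==(const IoSpec &o) const {
    return std::tie(name, indexes, has_deriv) ==
           std::tie(o.name, o.indexes, o.has_deriv);
  }
};

struct Request {  // hypothetical stand-in; misc_info omitted
  std::vector<IoSpec> inputs, outputs;
  bool need_model_derivative;
  bool store_component_stats;
  // std::vector's == compares sizes, then elements in order, using
  // IoSpec::operator== for each pair.
  bool operator==(const Request &o) const {
    return std::tie(inputs, outputs, need_model_derivative,
                    store_component_stats) ==
           std::tie(o.inputs, o.outputs, o.need_model_derivative,
                    o.store_component_stats);
  }
};
```

Because std::vector's == compares elements position by position, two requests that list the same inputs in a different order compare as unequal.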
void Print ( std::ostream &  os) const

This function is for printing info about the computation request in a human-readable way.

Definition at line 1055 of file nnet-computation.cc.

References ComputationRequest::need_model_derivative.

Referenced by kaldi::nnet3::CompileLoopedInternal(), CachingOptimizingCompiler::CompileNoShortcut(), ComputationGraphBuilder::ExplainWhyAllOutputsNotComputable(), kaldi::nnet3::UnitTestNnetCompile(), kaldi::nnet3::UnitTestNnetCompileLooped(), and kaldi::nnet3::UnitTestNnetCompileMulti().

{
  os << " # Computation request:\n";
  for (size_t i = 0; i < inputs.size(); i++) {
    os << "input-" << i << ": ";
    inputs[i].Print(os);
  }
  for (size_t i = 0; i < outputs.size(); i++) {
    os << "output-" << i << ": ";
    outputs[i].Print(os);
  }
  os << "need-model-derivative: " <<
      (need_model_derivative ? "true\n" : "false\n");
  os << "store-component-stats: " <<
      (store_component_stats ? "true\n" : "false\n");
  misc_info.Print(os);
}
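To show the kind of text Print() produces, here is a sketch that follows the same line layout. `IoSpec` and `Request` are hypothetical stand-ins; the per-spec format in `IoSpec::Print` is invented for illustration and does not reproduce IoSpecification::Print(), and misc_info is omitted.

```cpp
#include <cassert>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in: each spec prints as "<name> (<n> indexes)".
struct IoSpec {
  std::string name;
  std::vector<int> indexes;
  bool has_deriv;
  void Print(std::ostream &os) const {
    os << name << " (" << indexes.size() << " indexes)"
       << (has_deriv ? ", has-deriv" : "") << "\n";
  }
};

struct Request {  // hypothetical stand-in; misc_info omitted
  std::vector<IoSpec> inputs, outputs;
  bool need_model_derivative;
  bool store_component_stats;
  // Same line layout as ComputationRequest::Print() above.
  void Print(std::ostream &os) const {
    os << " # Computation request:\n";
    for (size_t i = 0; i < inputs.size(); i++) {
      os << "input-" << i << ": ";
      inputs[i].Print(os);
    }
    for (size_t i = 0; i < outputs.size(); i++) {
      os << "output-" << i << ": ";
      outputs[i].Print(os);
    }
    os << "need-model-derivative: "
       << (need_model_derivative ? "true\n" : "false\n");
    os << "store-component-stats: "
       << (store_component_stats ? "true\n" : "false\n");
  }
};
```

Because Print() takes any std::ostream, the same text can be sent to std::cerr or captured in a std::ostringstream for comparison.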
void Read ( std::istream & istream, bool binary )

Definition at line 997 of file nnet-computation.cc.

References kaldi::nnet3::ExpectToken(), KALDI_ASSERT, ComputationRequest::need_model_derivative, and kaldi::ReadBasicType().

Referenced by CachingOptimizingCompiler::ReadCache(), and kaldi::nnet3::UnitTestComputationRequestIo().

{
  ExpectToken(is, binary, "<ComputationRequest>");
  size_t num_inputs;
  ExpectToken(is, binary, "<NumInputs>");
  ReadBasicType(is, binary, &num_inputs);
  KALDI_ASSERT(num_inputs >= 0);
  inputs.resize(num_inputs);
  ExpectToken(is, binary, "<Inputs>");
  for (size_t c = 0; c < num_inputs; c++) {
    inputs[c].Read(is, binary);
  }

  size_t num_outputs;
  ExpectToken(is, binary, "<NumOutputs>");
  ReadBasicType(is, binary, &num_outputs);
  KALDI_ASSERT(num_outputs >= 0);
  outputs.resize(num_outputs);
  ExpectToken(is, binary, "<Outputs>");
  for (size_t c = 0; c < num_outputs; c++) {
    outputs[c].Read(is, binary);
  }

  ExpectToken(is, binary, "<NeedModelDerivative>");
  ReadBasicType(is, binary, &need_model_derivative);
  ExpectToken(is, binary, "<StoreComponentStats>");
  ReadBasicType(is, binary, &store_component_stats);
  ExpectToken(is, binary, "</ComputationRequest>");
}
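Read() is a fixed sequence of expected tokens interleaved with values. The sketch below parses a simplified text-mode version of that sequence. `ExpectTok`, `ReadBool`, `Request` and `ReadRequest` are hypothetical; each input or output is reduced to a single name token, and the T/F encoding of booleans is an assumption about the text format.

```cpp
#include <cassert>
#include <istream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical text-mode analogue of ExpectToken(): reads one
// whitespace-delimited token and fails loudly if it is not the expected one.
void ExpectTok(std::istream &is, const std::string &expected) {
  std::string tok;
  if (!(is >> tok) || tok != expected)
    throw std::runtime_error("expected " + expected + ", got '" + tok + "'");
}

// Reads a boolean written as a single "T" or "F" token (assumed encoding).
bool ReadBool(std::istream &is) {
  std::string tok;
  is >> tok;
  if (tok != "T" && tok != "F") throw std::runtime_error("bad bool: " + tok);
  return tok == "T";
}

struct Request {  // hypothetical stand-in: each spec reduced to its name
  std::vector<std::string> inputs, outputs;
  bool need_model_derivative = false;
  bool store_component_stats = false;
};

// Parses the same top-level token order that Read() expects.
Request ReadRequest(std::istream &is) {
  Request req;
  size_t n;
  ExpectTok(is, "<ComputationRequest>");
  ExpectTok(is, "<NumInputs>");
  if (!(is >> n)) throw std::runtime_error("bad <NumInputs>");
  ExpectTok(is, "<Inputs>");
  req.inputs.resize(n);
  for (size_t c = 0; c < n; c++) is >> req.inputs[c];
  ExpectTok(is, "<NumOutputs>");
  if (!(is >> n)) throw std::runtime_error("bad <NumOutputs>");
  ExpectTok(is, "<Outputs>");
  req.outputs.resize(n);
  for (size_t c = 0; c < n; c++) is >> req.outputs[c];
  ExpectTok(is, "<NeedModelDerivative>");
  req.need_model_derivative = ReadBool(is);
  ExpectTok(is, "<StoreComponentStats>");
  req.store_component_stats = ReadBool(is);
  ExpectTok(is, "</ComputationRequest>");
  return req;
}
```

As in Read(), a token out of place is reported immediately rather than skipped. Note also that num_inputs and num_outputs are size_t in Read(), so its KALDI_ASSERT(... >= 0) checks can never fail.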
void Write ( std::ostream & ostream, bool binary ) const

Definition at line 1026 of file nnet-computation.cc.

References ComputationRequest::need_model_derivative, kaldi::WriteBasicType(), and kaldi::WriteToken().

Referenced by kaldi::nnet3::UnitTestComputationRequestIo().

{
  WriteToken(os, binary, "<ComputationRequest>");
  if (!binary) os << std::endl;
  WriteToken(os, binary, "<NumInputs>");
  WriteBasicType(os, binary, inputs.size());
  if (!binary) os << std::endl;
  WriteToken(os, binary, "<Inputs>");
  for (size_t c = 0; c < inputs.size(); c++) {
    inputs[c].Write(os, binary);
  }
  if (!binary) os << std::endl;

  WriteToken(os, binary, "<NumOutputs>");
  WriteBasicType(os, binary, outputs.size());
  if (!binary) os << std::endl;
  WriteToken(os, binary, "<Outputs>");
  for (size_t c = 0; c < outputs.size(); c++) {
    outputs[c].Write(os, binary);
  }
  if (!binary) os << std::endl;

  WriteToken(os, binary, "<NeedModelDerivative>");
  WriteBasicType(os, binary, need_model_derivative);
  WriteToken(os, binary, "<StoreComponentStats>");
  WriteBasicType(os, binary, store_component_stats);
  WriteToken(os, binary, "</ComputationRequest>");
  if (!binary) os << std::endl;
}
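Write() is the mirror image of Read(): the same tokens in the same order, with each element count written before the elements. The following text-mode sketch makes simplifying assumptions: `Request` and `WriteRequest` are hypothetical, each input or output is reduced to its name, and booleans are written as T/F tokens. It is not a byte-exact copy of Kaldi's text format.

```cpp
#include <cassert>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

struct Request {  // hypothetical stand-in: each spec reduced to its name
  std::vector<std::string> inputs, outputs;
  bool need_model_derivative;
  bool store_component_stats;
};

// Text-mode sketch: space-separated tokens, a newline after each group,
// counts before elements, booleans as "T"/"F" (assumed encoding).
void WriteRequest(std::ostream &os, const Request &req) {
  os << "<ComputationRequest> \n";
  os << "<NumInputs> " << req.inputs.size() << " \n<Inputs> ";
  for (const std::string &name : req.inputs) os << name << ' ';
  os << "\n<NumOutputs> " << req.outputs.size() << " \n<Outputs> ";
  for (const std::string &name : req.outputs) os << name << ' ';
  os << "\n<NeedModelDerivative> " << (req.need_model_derivative ? 'T' : 'F')
     << " <StoreComponentStats> " << (req.store_component_stats ? 'T' : 'F')
     << " </ComputationRequest> \n";
}
```

Writing the count first lets a reader size its vectors before parsing the elements, which is what Read() does with inputs.resize(num_inputs).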

Member Data Documentation

bool store_component_stats

You should set store_component_stats to true if you need the average-activation and average-derivative statistics stored by the StoreStats() functions of components such as Tanh, Sigmoid and Softmax.

Definition at line 126 of file nnet-computation.h.

Referenced by kaldi::nnet3::ComputeExampleComputationRequestSimple(), DecodableNnetSimple::DoNnetComputation(), kaldi::nnet3::GetChainComputationRequest(), kaldi::nnet3::GetComputationRequest(), kaldi::nnet3::GetDiscriminativeComputationRequest(), ComputationRequest::operator==(), kaldi::nnet3::RequestIsDecomposable(), kaldi::nnet3::RunNnetComputation(), and kaldi::nnet3::UnitTestNnetCompileMulti().


The documentation for this struct was generated from the following files:
 nnet-computation.h
 nnet-computation.cc