ComputationRequest Struct Reference

#include <nnet-computation.h>


Public Member Functions

 ComputationRequest ()
 
bool NeedDerivatives () const
 Returns true if any of inputs[*].has_deriv is true, or if need_model_derivative is true.
 
int32 IndexForInput (const std::string &node_name) const
 Returns the index into "inputs" corresponding to the node with name "node_name", or -1 if there is no such index.
 
int32 IndexForOutput (const std::string &node_name) const
 Returns the index into "outputs" corresponding to the node with name "node_name", or -1 if there is no such index.
 
void Print (std::ostream &os) const
 Prints information about the computation request in human-readable form.
 
void Read (std::istream &istream, bool binary)
 
void Write (std::ostream &ostream, bool binary) const
 
bool operator== (const ComputationRequest &other) const
 

Public Attributes

std::vector< IoSpecification > inputs
 
std::vector< IoSpecification > outputs
 
bool need_model_derivative
 If need_model_derivative is true, then we'll be doing either model training or model-derivative computation, so updatable components need to be backprop'd.
 
bool store_component_stats
 Set store_component_stats to true if you need the average-activation and average-derivative statistics stored by the StoreStats() functions of components such as Tanh, Sigmoid and Softmax.
 
MiscComputationInfo misc_info
 misc_info is for extensibility to things that don't easily fit into the framework.
 

Detailed Description

Definition at line 114 of file nnet-computation.h.
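
The members listed above can be illustrated with a minimal, self-contained sketch that does not depend on Kaldi. IoSpec and Request are hypothetical stand-ins: IoSpec keeps only a name, a plain list of frame numbers in place of the real list of indexes, and has_deriv, while Request drops misc_info. MakeTrainingRequest is likewise a hypothetical helper showing one plausible way to fill in a request for training.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for kaldi::nnet3::IoSpecification. The real struct
// stores Index objects; plain frame numbers are an assumption of this sketch.
struct IoSpec {
  std::string name;
  std::vector<int> frames;
  bool has_deriv;
};

// Hypothetical stand-in for ComputationRequest (misc_info omitted). The
// flags default to false, matching the ComputationRequest constructor.
struct Request {
  std::vector<IoSpec> inputs;
  std::vector<IoSpec> outputs;
  bool need_model_derivative = false;
  bool store_component_stats = false;
};

// Hypothetical helper: a request covering frames [0, num_frames) in which
// the derivative enters at the output and updatable components are
// backpropagated into.
Request MakeTrainingRequest(int num_frames) {
  assert(num_frames > 0);
  std::vector<int> frames;
  for (int t = 0; t < num_frames; t++) frames.push_back(t);

  Request req;
  req.inputs.push_back({"input", frames, false});   // no input derivative
  req.outputs.push_back({"output", frames, true});  // derivative supplied here
  req.need_model_derivative = true;   // backprop into updatable components
  req.store_component_stats = true;   // optional: keep activation statistics
  return req;
}
```

In this configuration the only derivative supplied is at the output, while need_model_derivative asks for gradients of the model parameters; NeedDerivatives() accepts this combination because an output has has_deriv set.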

Constructor & Destructor Documentation

◆ ComputationRequest()

ComputationRequest ( )
inline

Definition at line 132 of file nnet-computation.h.

References MiscComputationInfo::operator==(), and MiscComputationInfo::Print().

ComputationRequest(): need_model_derivative(false),
                      store_component_stats(false) { }

Member Function Documentation

◆ IndexForInput()

int32 IndexForInput (const std::string &node_name) const

Returns the index into "inputs" corresponding to the node with name "node_name", or -1 if there is no such index.

It is an error (checked with KALDI_ASSERT) if more than one input has this name.

Definition at line 54 of file nnet-computation.cc.

References ComputationRequest::inputs and KALDI_ASSERT.

Referenced by Compiler::ComputeDerivNeeded().

int32 ComputationRequest::IndexForInput(
    const std::string &node_name) const {
  int32 ans = -1;
  for (size_t i = 0; i < inputs.size(); i++) {
    if (inputs[i].name == node_name) {
      KALDI_ASSERT(ans == -1 && "Two inputs with the same name");
      ans = i;
    }
  }
  return ans;
}
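
The lookup can be exercised outside Kaldi with a short sketch. NamedSpec and IndexForName are hypothetical names, int32 is assumed to be a 32-bit signed integer (as Kaldi's typedef is), and assert() stands in for KALDI_ASSERT.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

using int32 = std::int32_t;  // assumed to match Kaldi's int32

// Hypothetical stand-in keeping only the field the lookup needs.
struct NamedSpec {
  std::string name;
};

// Same linear scan as IndexForInput(): returns the position of the spec
// named node_name, or -1 if there is none. The loop never breaks early, so
// a second spec with the same name trips the assertion.
int32 IndexForName(const std::vector<NamedSpec> &specs,
                   const std::string &node_name) {
  int32 ans = -1;
  for (std::size_t i = 0; i < specs.size(); i++) {
    if (specs[i].name == node_name) {
      assert(ans == -1 && "Two specs with the same name");
      ans = static_cast<int32>(i);
    }
  }
  return ans;
}
```

Scanning the whole vector instead of returning at the first match costs nothing for the handful of inputs a request normally has, and it is what allows the duplicate-name check.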

◆ IndexForOutput()

int32 IndexForOutput (const std::string &node_name) const

Returns the index into "outputs" corresponding to the node with name "node_name", or -1 if there is no such index.

It is an error (checked with KALDI_ASSERT) if more than one output has this name.

Definition at line 66 of file nnet-computation.cc.

References KALDI_ASSERT and ComputationRequest::outputs.

Referenced by Compiler::ComputeDerivNeeded().

int32 ComputationRequest::IndexForOutput(
    const std::string &node_name) const {
  int32 ans = -1;
  for (size_t i = 0; i < outputs.size(); i++) {
    if (outputs[i].name == node_name) {
      KALDI_ASSERT(ans == -1 && "Two outputs with the same name");
      ans = i;
    }
  }
  return ans;
}
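
IndexForOutput() differs from IndexForInput() only in scanning outputs, so callers still have to handle the -1 case before indexing. The sketch below shows that pattern with hypothetical stand-ins (OutputSpec, OutputSuppliesDeriv); it is an illustration, not code taken from Compiler::ComputeDerivNeeded().

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

using int32 = std::int32_t;  // assumed to match Kaldi's int32

// Hypothetical stand-in for IoSpecification.
struct OutputSpec {
  std::string name;
  bool has_deriv;
};

// Mirrors IndexForOutput(): scans outputs for node_name, -1 if absent.
int32 IndexForOutput(const std::vector<OutputSpec> &outputs,
                     const std::string &node_name) {
  int32 ans = -1;
  for (std::size_t i = 0; i < outputs.size(); i++) {
    if (outputs[i].name == node_name) {
      assert(ans == -1 && "Two outputs with the same name");
      ans = static_cast<int32>(i);
    }
  }
  return ans;
}

// Hypothetical caller: check for -1 before using the index.
bool OutputSuppliesDeriv(const std::vector<OutputSpec> &outputs,
                         const std::string &node_name) {
  int32 idx = IndexForOutput(outputs, node_name);
  if (idx == -1) return false;  // node is not requested as an output
  return outputs[idx].has_deriv;
}
```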

◆ NeedDerivatives()

bool NeedDerivatives ( ) const

Returns true if any of inputs[*].has_deriv is true, or if need_model_derivative is true. If derivatives are needed but no output has has_deriv set, it reports an error with KALDI_ERR.

Definition at line 29 of file nnet-computation.cc.

References ComputationRequest::inputs, KALDI_ERR, ComputationRequest::need_model_derivative, and ComputationRequest::outputs.

Referenced by Compiler::SetUpPrecomputedIndexes().

bool ComputationRequest::NeedDerivatives() const {
  bool ans = false;
  if (need_model_derivative)
    ans = true;
  for (size_t i = 0; i < inputs.size(); i++) {
    if (inputs[i].has_deriv) {  // derivative requested for this input
      ans = true;
      break;
    }
  }
  if (ans) {
    // check that the output actually provides a derivative, else the
    // request could not be meaningfully satisfied.
    size_t i;
    for (i = 0; i < outputs.size(); i++)
      if (outputs[i].has_deriv)
        break;
    if (i == outputs.size()) {
      KALDI_ERR << "You requested model derivatives or input derivatives, but "
                << "provide no derivatives at the output.";
    }
  }
  return ans;
}
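
The decision rule in NeedDerivatives() can be checked in isolation. In this sketch Spec is a hypothetical stand-in that keeps only has_deriv, and throwing std::runtime_error replaces KALDI_ERR. Derivatives are needed if need_model_derivative is set or any input has has_deriv, and in that case at least one output must supply a derivative.

```cpp
#include <cassert>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in: only has_deriv matters for this check.
struct Spec {
  bool has_deriv;
};

// Same decision logic as ComputationRequest::NeedDerivatives(), with an
// exception standing in for KALDI_ERR.
bool NeedDerivatives(const std::vector<Spec> &inputs,
                     const std::vector<Spec> &outputs,
                     bool need_model_derivative) {
  bool ans = need_model_derivative;
  for (const Spec &in : inputs) {
    if (in.has_deriv) {  // derivative requested for this input
      ans = true;
      break;
    }
  }
  if (ans) {
    // Some output must provide the derivative that is backpropagated.
    bool output_has_deriv = false;
    for (const Spec &out : outputs) {
      if (out.has_deriv) {
        output_has_deriv = true;
        break;
      }
    }
    if (!output_has_deriv)
      throw std::runtime_error(
          "derivatives requested but no output provides one");
  }
  return ans;
}
```

A request that needs no derivatives is accepted even when no output has has_deriv, because the output check only runs once ans is true.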

◆ operator==()

bool operator== (const ComputationRequest &other) const

Definition at line 1125 of file nnet-computation.cc.

References ComputationRequest::inputs, ComputationRequest::misc_info, ComputationRequest::need_model_derivative, ComputationRequest::outputs, and ComputationRequest::store_component_stats.

bool ComputationRequest::operator== (const ComputationRequest &other) const {
  // rely on the std::vector's default implementation of ==, which in turn
  // relies on the == operator of class IoSpecification.
  return inputs == other.inputs && outputs == other.outputs &&
      need_model_derivative == other.need_model_derivative &&
      store_component_stats == other.store_component_stats &&
      misc_info == other.misc_info;
}
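
The comparison above is spelled out member by member with &&. An equivalent formulation, sketched here on hypothetical stand-in types (misc_info omitted), compares std::tie tuples instead; this is a stylistic alternative, not what Kaldi does.

```cpp
#include <cassert>
#include <string>
#include <tuple>
#include <vector>

// Hypothetical stand-in for IoSpecification with its own operator==.
struct Spec {
  std::string name;
  bool has_deriv;
  bool operator==(const Spec &other) const {
    return name == other.name && has_deriv == other.has_deriv;
  }
};

// Hypothetical stand-in for ComputationRequest (misc_info omitted).
struct Request {
  std::vector<Spec> inputs, outputs;
  bool need_model_derivative = false;
  bool store_component_stats = false;

  // std::tie builds tuples of references; tuple == compares them element
  // by element, and std::vector == calls Spec::operator== pairwise.
  bool operator==(const Request &other) const {
    return std::tie(inputs, outputs, need_model_derivative,
                    store_component_stats) ==
           std::tie(other.inputs, other.outputs, other.need_model_derivative,
                    other.store_component_stats);
  }
};
```

Either form requires operator== on the element type, since std::vector's == compares elements pairwise. The std::tie version keeps the member list in one place per side, but every member still has to be named explicitly.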

◆ Print()

void Print (std::ostream &os) const

This function is for printing info about the computation request in a human-readable way.

Definition at line 1092 of file nnet-computation.cc.

References ComputationRequest::inputs, ComputationRequest::misc_info, ComputationRequest::need_model_derivative, ComputationRequest::outputs, MiscComputationInfo::Print(), and ComputationRequest::store_component_stats.

Referenced by kaldi::nnet3::CompileLoopedInternal(), CachingOptimizingCompiler::CompileNoShortcut(), kaldi::nnet3::UnitTestNnetCompile(), kaldi::nnet3::UnitTestNnetCompileLooped(), and kaldi::nnet3::UnitTestNnetCompileMulti().

void ComputationRequest::Print(std::ostream &os) const {
  os << " # Computation request:\n";
  for (size_t i = 0; i < inputs.size(); i++) {
    os << "input-" << i << ": ";
    inputs[i].Print(os);
  }
  for (size_t i = 0; i < outputs.size(); i++) {
    os << "output-" << i << ": ";
    outputs[i].Print(os);
  }
  os << "need-model-derivative: " <<
      (need_model_derivative ? "true\n" : "false\n");
  os << "store-component-stats: " <<
      (store_component_stats ? "true\n" : "false\n");
  misc_info.Print(os);
}
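
The output layout can be reproduced in a self-contained sketch that writes into a std::ostringstream, so the result can be inspected as a string. The one-line format produced by the hypothetical Spec::Print is an assumption (the real IoSpecification::Print() output is not shown on this page), and misc_info is omitted.

```cpp
#include <cassert>
#include <cstddef>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for IoSpecification; its line format is invented.
struct Spec {
  std::string name;
  bool has_deriv;
  void Print(std::ostream &os) const {
    os << "name=" << name << ", has-deriv=" << (has_deriv ? "true" : "false")
       << "\n";
  }
};

// Follows the layout of ComputationRequest::Print(), minus misc_info.
void PrintRequest(std::ostream &os, const std::vector<Spec> &inputs,
                  const std::vector<Spec> &outputs,
                  bool need_model_derivative, bool store_component_stats) {
  os << " # Computation request:\n";
  for (std::size_t i = 0; i < inputs.size(); i++) {
    os << "input-" << i << ": ";
    inputs[i].Print(os);
  }
  for (std::size_t i = 0; i < outputs.size(); i++) {
    os << "output-" << i << ": ";
    outputs[i].Print(os);
  }
  os << "need-model-derivative: "
     << (need_model_derivative ? "true\n" : "false\n");
  os << "store-component-stats: "
     << (store_component_stats ? "true\n" : "false\n");
}
```

Printing to an arbitrary std::ostream, rather than directly to std::cout, is what makes this kind of output easy to capture in unit tests or log files.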

◆ Read()

void Read (std::istream &istream, bool binary)

Definition at line 1034 of file nnet-computation.cc.

References kaldi::nnet3::ExpectToken(), ComputationRequest::inputs, KALDI_ASSERT, ComputationRequest::need_model_derivative, ComputationRequest::outputs, kaldi::ReadBasicType(), and ComputationRequest::store_component_stats.

Referenced by ComputationCache::Read(), and kaldi::nnet3::UnitTestComputationRequestIo().

void ComputationRequest::Read(std::istream &is, bool binary) {
  ExpectToken(is, binary, "<ComputationRequest>");
  size_t num_inputs;
  ExpectToken(is, binary, "<NumInputs>");
  ReadBasicType(is, binary, &num_inputs);
  KALDI_ASSERT(num_inputs >= 0);
  inputs.resize(num_inputs);
  ExpectToken(is, binary, "<Inputs>");
  for (size_t c = 0; c < num_inputs; c++) {
    inputs[c].Read(is, binary);
  }

  size_t num_outputs;
  ExpectToken(is, binary, "<NumOutputs>");
  ReadBasicType(is, binary, &num_outputs);
  KALDI_ASSERT(num_outputs >= 0);
  outputs.resize(num_outputs);
  ExpectToken(is, binary, "<Outputs>");
  for (size_t c = 0; c < num_outputs; c++) {
    outputs[c].Read(is, binary);
  }

  ExpectToken(is, binary, "<NeedModelDerivative>");
  ReadBasicType(is, binary, &need_model_derivative);
  ExpectToken(is, binary, "<StoreComponentStats>");
  ReadBasicType(is, binary, &store_component_stats);
  ExpectToken(is, binary, "</ComputationRequest>");
}
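
The reading order is easiest to follow in text mode, where every token and value is whitespace-separated. The sketch below parses that token sequence with plain iostreams. Its assumptions are all hypothetical simplifications: ExpectTok and ReadBool stand in for kaldi::ExpectToken and kaldi::ReadBasicType, each IoSpecification is reduced to a single name token, and bools are taken to appear as T or F.

```cpp
#include <cassert>
#include <cstddef>
#include <istream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Text-mode stand-in for ExpectToken: the next token must match exactly.
void ExpectTok(std::istream &is, const std::string &expected) {
  std::string tok;
  if (!(is >> tok) || tok != expected)
    throw std::runtime_error("expected " + expected + ", got '" + tok + "'");
}

// Assumption of this sketch: bools are written as "T" or "F".
bool ReadBool(std::istream &is) {
  std::string tok;
  is >> tok;
  if (tok == "T") return true;
  if (tok == "F") return false;
  throw std::runtime_error("bad bool token '" + tok + "'");
}

struct ParsedRequest {
  std::vector<std::string> input_names, output_names;
  bool need_model_derivative = false;
  bool store_component_stats = false;
};

// Same token order as ComputationRequest::Read(); each IoSpecification is
// (hypothetically) a single name token here.
ParsedRequest ReadRequest(std::istream &is) {
  ParsedRequest r;
  std::size_t n;
  ExpectTok(is, "<ComputationRequest>");
  ExpectTok(is, "<NumInputs>");
  if (!(is >> n)) throw std::runtime_error("bad input count");
  ExpectTok(is, "<Inputs>");
  r.input_names.resize(n);
  for (std::size_t c = 0; c < n; c++) is >> r.input_names[c];
  ExpectTok(is, "<NumOutputs>");
  if (!(is >> n)) throw std::runtime_error("bad output count");
  ExpectTok(is, "<Outputs>");
  r.output_names.resize(n);
  for (std::size_t c = 0; c < n; c++) is >> r.output_names[c];
  ExpectTok(is, "<NeedModelDerivative>");
  r.need_model_derivative = ReadBool(is);
  ExpectTok(is, "<StoreComponentStats>");
  r.store_component_stats = ReadBool(is);
  ExpectTok(is, "</ComputationRequest>");
  return r;
}
```

Because num_inputs and num_outputs are unsigned size_t values, the KALDI_ASSERT(num_inputs >= 0) and KALDI_ASSERT(num_outputs >= 0) checks in Read() are always true; malformed input is in practice caught by the token checks that follow.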

◆ Write()

void Write (std::ostream &ostream, bool binary) const

Definition at line 1063 of file nnet-computation.cc.

References ComputationRequest::inputs, ComputationRequest::need_model_derivative, ComputationRequest::outputs, ComputationRequest::store_component_stats, kaldi::WriteBasicType(), and kaldi::WriteToken().

Referenced by kaldi::nnet3::UnitTestComputationRequestIo().

void ComputationRequest::Write(std::ostream &os, bool binary) const {
  WriteToken(os, binary, "<ComputationRequest>");
  if (!binary) os << std::endl;
  WriteToken(os, binary, "<NumInputs>");
  WriteBasicType(os, binary, inputs.size());
  if (!binary) os << std::endl;
  WriteToken(os, binary, "<Inputs>");
  for (size_t c = 0; c < inputs.size(); c++) {
    inputs[c].Write(os, binary);
  }
  if (!binary) os << std::endl;

  WriteToken(os, binary, "<NumOutputs>");
  WriteBasicType(os, binary, outputs.size());
  if (!binary) os << std::endl;
  WriteToken(os, binary, "<Outputs>");
  for (size_t c = 0; c < outputs.size(); c++) {
    outputs[c].Write(os, binary);
  }
  if (!binary) os << std::endl;

  WriteToken(os, binary, "<NeedModelDerivative>");
  WriteBasicType(os, binary, need_model_derivative);
  WriteToken(os, binary, "<StoreComponentStats>");
  WriteBasicType(os, binary, store_component_stats);
  WriteToken(os, binary, "</ComputationRequest>");
  if (!binary) os << std::endl;
}
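
Write() must emit each value immediately after its token, in exactly the order Read() consumes them. The sketch below produces a text-mode token sequence of that shape using plain iostreams. WriteRequestText is a hypothetical name, and its format is a simplification: each IoSpecification is reduced to its name, bools are written as T or F, and every token is followed by a space.

```cpp
#include <cassert>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical text-mode writer mirroring the token order of Write().
void WriteRequestText(std::ostream &os,
                      const std::vector<std::string> &input_names,
                      const std::vector<std::string> &output_names,
                      bool need_model_derivative, bool store_component_stats) {
  auto tok = [&os](const std::string &t) { os << t << ' '; };
  auto flag = [&os](bool b) { os << (b ? "T " : "F "); };

  tok("<ComputationRequest>");
  os << '\n';
  tok("<NumInputs>");
  os << input_names.size() << " \n";
  tok("<Inputs>");
  for (const std::string &name : input_names) tok(name);
  os << '\n';
  tok("<NumOutputs>");
  os << output_names.size() << " \n";
  tok("<Outputs>");
  for (const std::string &name : output_names) tok(name);
  os << '\n';
  // Each flag value must follow its token, or a reader will fail.
  tok("<NeedModelDerivative>");
  flag(need_model_derivative);
  tok("<StoreComponentStats>");
  flag(store_component_stats);
  tok("</ComputationRequest>");
  os << '\n';
}
```

The newlines only make the text form easier to read; a whitespace-based reader ignores them, which is why the original code emits them only when binary is false.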

Member Data Documentation

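◆ inputs

std::vector< IoSpecification > inputs

◆ misc_info

MiscComputationInfo misc_info

◆ need_model_derivative

bool need_model_derivative

◆ outputs

std::vector< IoSpecification > outputs

◆ store_component_stats

bool store_component_stats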


The documentation for this struct was generated from the following files:

nnet-computation.h
nnet-computation.cc