All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
kaldi::nnet3::computation_graph Namespace Reference

Functions

void AddOutputToGraph (const ComputationRequest &request, const Nnet &nnet, ComputationGraph *graph)
 
void AddInputToGraph (const ComputationRequest &request, const Nnet &nnet, ComputationGraph *graph)
 
static void ComputeDependenciesSubset (const ComputationGraph &graph, const std::vector< int32 > &cindex_id_to_epoch, std::vector< std::vector< int32 > > *dependencies_subset)
 This function outputs to dependencies_subset[c], for each cindex_id c, the subset of elements d of graph.dependencies[c] such that cindex_id_to_epoch[d] == cindex_id_to_epoch[c]. More...
 
static void ComputeEpochInfo (const Nnet &nnet, const ComputationGraph &graph, std::vector< int32 > *cindex_id_to_epoch, std::vector< std::vector< int32 > > *epochs, std::vector< bool > *epoch_is_trivial)
 This function computes certain information about "epochs" of cindex_ids. More...
 

Function Documentation

void kaldi::nnet3::computation_graph::AddInputToGraph ( const ComputationRequest &  request,
const Nnet &  nnet,
ComputationGraph *  graph 
)

Definition at line 931 of file nnet-computation-graph.cc.

References ComputationGraph::GetCindexId(), Nnet::GetNode(), Nnet::GetNodeIndex(), rnnlm::i, ComputationRequest::inputs, rnnlm::j, KALDI_ASSERT, KALDI_ERR, kaldi::nnet3::kComponent, kaldi::nnet3::kInput, rnnlm::n, and NetworkNode::node_type.

Referenced by kaldi::nnet3::ComputeComputationGraph().

933  {
934  int32 num_added = 0;
935  for (int32 i = 0; i < request.inputs.size(); i++) {
936  int32 n = nnet.GetNodeIndex(request.inputs[i].name);
937  if (n == -1)
938  KALDI_ERR << "Network has no input with name "
939  << request.inputs[i].name;
940  NodeType t = nnet.GetNode(n).node_type;
941  KALDI_ASSERT((t == kInput || t == kComponent) &&
942  "Inputs to graph only allowed for Input and Component nodes.");
943 
944  for (int32 j = 0; j < request.inputs[i].indexes.size(); j++) {
945  Cindex cindex(n, request.inputs[i].indexes[j]);
946  bool is_input = true, is_new;
947  graph->GetCindexId(cindex, is_input, &is_new); // ignore the return value.
948  KALDI_ASSERT(is_new && "Input index seems to be listed more than once");
949  num_added++;
950  }
951  }
952  KALDI_ASSERT(num_added > 0 && "AddInputToGraph: nothing to add.");
953 }
std::pair< int32, Index > Cindex
Definition: nnet-common.h:100
struct rnnlm::@11::@12 n
#define KALDI_ERR
Definition: kaldi-error.h:127
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:169
void kaldi::nnet3::computation_graph::AddOutputToGraph ( const ComputationRequest &  request,
const Nnet &  nnet,
ComputationGraph *  graph 
)

Definition at line 908 of file nnet-computation-graph.cc.

References ComputationGraph::GetCindexId(), Nnet::GetNodeIndex(), rnnlm::i, rnnlm::j, KALDI_ASSERT, KALDI_ERR, rnnlm::n, and ComputationRequest::outputs.

Referenced by kaldi::nnet3::ComputeComputationGraph().

910  {
911  int32 num_added = 0;
912  for (int32 i = 0; i < request.outputs.size(); i++) {
913  int32 n = nnet.GetNodeIndex(request.outputs[i].name);
914  if (n == -1)
915  KALDI_ERR << "Network has no output with name "
916  << request.outputs[i].name;
917  for (int32 j = 0; j < request.outputs[i].indexes.size(); j++) {
918  Cindex cindex(n, request.outputs[i].indexes[j]);
919  bool is_input = false, is_new;
920  graph->GetCindexId(cindex, is_input, &is_new); // ignore the return value.
921  KALDI_ASSERT(is_new && "Output index seems to be listed more than once");
922  num_added++;
923  }
924  }
925  KALDI_ASSERT(num_added > 0 && "AddOutputToGraph: nothing to add.");
926 }
std::pair< int32, Index > Cindex
Definition: nnet-common.h:100
struct rnnlm::@11::@12 n
#define KALDI_ERR
Definition: kaldi-error.h:127
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:169
static void kaldi::nnet3::computation_graph::ComputeDependenciesSubset ( const ComputationGraph &  graph,
const std::vector< int32 > &  cindex_id_to_epoch,
std::vector< std::vector< int32 > > *  dependencies_subset 
)
static

This function outputs to dependencies_subset[c], for each cindex_id c, the subset of elements d of graph.dependencies[c] such that cindex_id_to_epoch[d] == cindex_id_to_epoch[c].

That is, it's the dependency graph of the entire computation, with any links that go from one epoch to a different epoch removed. Topologically, 'dependencies_subset' would therefore consist of a bunch of disconnected graphs.

Definition at line 965 of file nnet-computation-graph.cc.

References ComputationGraph::cindexes, rnnlm::d, ComputationGraph::dependencies, rnnlm::i, and KALDI_ASSERT.

Referenced by kaldi::nnet3::ComputeComputationPhases().

968  {
969  int32 num_cindex_ids = graph.cindexes.size();
970  KALDI_ASSERT(cindex_id_to_epoch.size() == num_cindex_ids);
971  dependencies_subset->resize(num_cindex_ids);
972  for (int32 cindex_id = 0; cindex_id < num_cindex_ids; cindex_id++) {
973  int32 phase_index = cindex_id_to_epoch[cindex_id];
974  const std::vector<int32> &dependencies = graph.dependencies[cindex_id];
975  std::vector<int32> &dep_subset = (*dependencies_subset)[cindex_id];
976  int32 num_dep = dependencies.size();
977  for (int32 i = 0; i < num_dep; i++) {
978  int32 d = dependencies[i];
979  if (cindex_id_to_epoch[d] == phase_index)
980  dep_subset.push_back(d);
981  }
982  }
983 }
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:169
static void kaldi::nnet3::computation_graph::ComputeEpochInfo ( const Nnet &  nnet,
const ComputationGraph &  graph,
std::vector< int32 > *  cindex_id_to_epoch,
std::vector< std::vector< int32 > > *  epochs,
std::vector< bool > *  epoch_is_trivial 
)
static

This function computes certain information about "epochs" of cindex_ids.

The function ComputeNnetComputationEpochs() from nnet-graph.h gives us a map from the NetworkNode index to an index we call the "epoch" index: basically, nodes that are computed first have a lower epoch index, and all nodes that are part of the same strongly connected component have the same epoch index. In an acyclic nnet graph, each component will usually have its own epoch index, but in recurrent structures such as LSTMs, each LSTM layer (with multiple components) will have its own epoch index.

The overall computation order that we compute will respect this ordering into epochs, except that outputs of nodes of type kComponent that are actually provided as inputs to the network won't be subject to these limitations, but will come first in the order. We simply ignore the output of this function as it concerns cindex_ids that are provided as input to the network.

Parameters
nnet[in] The neural net
graph[in] The computation graph
cindex_id_to_epoch[out] A vector that maps cindex_id to epoch index, as obtained by adding one to the output of ComputeNnetComputationEpochs; however, input cindex_ids (those for which is_input[cindex_id] is true) always map to 0. Note: the epoch-index only depends on the neural network's topology of nodes; a node in the network should always map to the same epoch-index regardless of the computation, and we assign cindexes to epochs just based on what node the cindexes are part of.
epochs[out] The same information as cindex_id_to_epoch, but in a different format: for each epoch, a list of cindex_ids with that epoch index.
epoch_is_trivial[out] A vector of bool, indexed by epoch index, that is true if this epoch index corresponds to just a single NetworkNode (and also true for epoch index 0, which corresponds only to inputs to the network).

Definition at line 1019 of file nnet-computation-graph.cc.

References ComputationGraph::cindexes, kaldi::nnet3::ComputeNnetComputationEpochs(), rnnlm::i, ComputationGraph::is_input, KALDI_ASSERT, KALDI_VLOG, rnnlm::n, Nnet::NumNodes(), and kaldi::nnet3::PrintIntegerVector().

Referenced by kaldi::nnet3::ComputeComputationPhases().

1024  {
1025 
1026  // node_to_epoch maps each nnet node to an index >= 0 that tells us coarsely
1027  // what order to compute them in... but we may need to compute a finer
1028  // ordering at the cindex_id level in cases like RNNs.
1029  std::vector<int32> node_to_epoch;
1030  ComputeNnetComputationEpochs(nnet, &node_to_epoch);
1031  {
1032  std::ostringstream os;
1033  PrintIntegerVector(os, node_to_epoch);
1034  KALDI_VLOG(6) << "node_to_epoch: " << os.str();
1035  }
1036 
1037  // Add one to the epoch numbering because we will be reserving
1038  // zero for inputs to the network, and we don't want to have to
1039  // prove that epoch number 0 would correspond only to inputs.
1040  for (int32 i = 0; i < node_to_epoch.size(); i++)
1041  node_to_epoch[i]++;
1042  int32 num_nodes = nnet.NumNodes(),
1043  num_cindex_ids = graph.cindexes.size(),
1044  num_epoch_indexes = 1 + *std::max_element(node_to_epoch.begin(),
1045  node_to_epoch.end());
1046  KALDI_ASSERT(node_to_epoch.size() == num_nodes);
1047 
1048  // epoch_to_num_nodes is only used so we know whether each epoch
1049  // index corresponds to multiple nodes; if it's just one node then we know
1050  // the computation is very simple and we can do an optimization.
1051  std::vector<int32> epoch_to_num_nodes(num_epoch_indexes, 0);
1052  for (int32 n = 0; n < num_nodes; n++)
1053  epoch_to_num_nodes[node_to_epoch[n]]++;
1054 
1055  epoch_is_trivial->resize(num_epoch_indexes);
1056  for (int32 o = 0; o < num_epoch_indexes; o++) {
1057  KALDI_ASSERT(o == 0 || epoch_to_num_nodes[o] > 0);
1058  (*epoch_is_trivial)[o] = (epoch_to_num_nodes[o] <= 1);
1059  }
1060 
1061  cindex_id_to_epoch->resize(num_cindex_ids);
1062  epochs->resize(num_epoch_indexes);
1063  for (int32 cindex_id = 0; cindex_id < num_cindex_ids; cindex_id++) {
1064  int32 node_index = graph.cindexes[cindex_id].first,
1065  epoch_index = (graph.is_input[cindex_id] ? 0 :
1066  node_to_epoch[node_index]);
1067  (*cindex_id_to_epoch)[cindex_id] = epoch_index;
1068  (*epochs)[epoch_index].push_back(cindex_id);
1069  }
1070 }
struct rnnlm::@11::@12 n
void ComputeNnetComputationEpochs(const Nnet &nnet, std::vector< int32 > *node_to_epoch)
This function computes the order in which we need to compute each node in the graph, where each node-index n maps to an epoch-index t = 0, 1, ...
Definition: nnet-graph.cc:265
void PrintIntegerVector(std::ostream &os, const std::vector< int32 > &ints)
Definition: nnet-common.cc:448
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:169
#define KALDI_VLOG(v)
Definition: kaldi-error.h:136