nnet-graph.cc
Go to the documentation of this file.
1 // nnet3/nnet-graph.cc
2 
3 // Copyright 2015 Johns Hopkins University (author: Daniel Povey)
4 // 2015 Guoguo Chen
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include <iterator>
22 #include <sstream>
23 #include "nnet3/nnet-graph.h"
24 
25 namespace kaldi {
26 namespace nnet3 {
27 
28 
29 
30 void NnetToDirectedGraph(const Nnet &nnet,
31  std::vector<std::vector<int32> > *graph) {
32  graph->clear();
33  int32 num_nodes = nnet.NumNodes();
34  graph->resize(num_nodes);
35  for (int32 n = 0; n < num_nodes; n++) {
36  const NetworkNode &node = nnet.GetNode(n);
37  // handle dependencies of this node.
38  std::vector<int32> node_dependencies;
39  switch (node.node_type) {
40  case kInput:
41  break; // no node dependencies.
42  case kDescriptor:
43  node.descriptor.GetNodeDependencies(&node_dependencies);
44  break;
45  case kComponent:
46  node_dependencies.push_back(n - 1);
47  break;
48  case kDimRange:
49  node_dependencies.push_back(node.u.node_index);
50  break;
51  default:
52  KALDI_ERR << "Invalid node type";
53  }
54  SortAndUniq(&node_dependencies);
55  for (size_t i = 0; i < node_dependencies.size(); i++) {
56  int32 dep_n = node_dependencies[i];
57  KALDI_ASSERT(dep_n >= 0 && dep_n < num_nodes);
58  (*graph)[dep_n].push_back(n);
59  }
60  }
61 }
62 
63 void ComputeGraphTranspose(const std::vector<std::vector<int32> > &graph,
64  std::vector<std::vector<int32> > *graph_transpose) {
65  int32 size = graph.size();
66  graph_transpose->clear();
67  graph_transpose->resize(size);
68  for (int32 n = 0; n < size; n++) {
69  const std::vector<int32> &nodes = graph[n];
70  std::vector<int32>::const_iterator iter = nodes.begin(), end = nodes.end();
71  for (; iter != end; ++iter) {
72  int32 dest = *iter;
73  (*graph_transpose)[dest].push_back(n);
74  }
75  }
76 }
77 
78 struct TarjanNode {
81  bool on_stack;
82  TarjanNode() : index(-1), lowlink(-1), on_stack(false) {}
83 };
84 
86  const std::vector<std::vector<int32> > &graph,
87  int32 *global_index,
88  std::vector<TarjanNode> *tarjan_nodes,
89  std::vector<int32> *tarjan_stack,
90  std::vector<std::vector<int32> > *sccs) {
91  KALDI_ASSERT(sccs != NULL);
92  KALDI_ASSERT(tarjan_nodes != NULL);
93  KALDI_ASSERT(tarjan_stack != NULL);
94  KALDI_ASSERT(global_index != NULL);
95  KALDI_ASSERT(node >= 0 && node < graph.size());
96 
97  // Initializes the current Tarjan node.
98  (*tarjan_nodes)[node].index = *global_index;
99  (*tarjan_nodes)[node].lowlink = *global_index;
100  *global_index += 1;
101  (*tarjan_nodes)[node].on_stack = true;
102  tarjan_stack->push_back(node);
103 
104  // DFS from the current node.
105  for (int32 i = 0; i < graph[node].size(); ++i) {
106  int32 next = graph[node][i];
107 
108  if ((*tarjan_nodes)[next].index == -1) {
109  // First time we see this node.
110  TarjanSccRecursive(next, graph,
111  global_index, tarjan_nodes, tarjan_stack, sccs);
112  (*tarjan_nodes)[node].lowlink = std::min((*tarjan_nodes)[node].lowlink,
113  (*tarjan_nodes)[next].lowlink);
114  } else if ((*tarjan_nodes)[next].on_stack) {
115  // Next node is on the stack -- back edge. We can't use the lowlink of
116  // next node, because that may point to the index of the root, while the
117  // current node can't be the root.
118  (*tarjan_nodes)[node].lowlink = std::min((*tarjan_nodes)[node].lowlink,
119  (*tarjan_nodes)[next].index);
120  }
121  }
122 
123  // Output SCC.
124  if ((*tarjan_nodes)[node].index == (*tarjan_nodes)[node].lowlink) {
125  std::vector<int32> scc;
126  int32 pop_node;
127  do {
128  pop_node = tarjan_stack->back();
129  tarjan_stack->pop_back();
130  (*tarjan_nodes)[pop_node].on_stack = false;
131  scc.push_back(pop_node);
132  } while (pop_node != node);
133  KALDI_ASSERT(pop_node == node);
134  sccs->push_back(scc);
135  }
136 }
137 
138 void FindSccsTarjan(const std::vector<std::vector<int32> > &graph,
139  std::vector<std::vector<int32> > *sccs) {
140  KALDI_ASSERT(sccs != NULL);
141 
142  // Initialization.
143  std::vector<TarjanNode> tarjan_nodes(graph.size());
144  std::vector<int32> tarjan_stack;
145  int32 global_index = 0;
146 
147  // Calls the recursive function.
148  for (int32 n = 0; n < graph.size(); ++n) {
149  if (tarjan_nodes[n].index == -1) {
150  TarjanSccRecursive(n, graph,
151  &global_index, &tarjan_nodes, &tarjan_stack, sccs);
152  }
153  }
154 }
155 
156 void FindSccs(const std::vector<std::vector<int32> > &graph,
157  std::vector<std::vector<int32> > *sccs) {
158  // Internally we call Tarjan's SCC algorithm, as it only requires one DFS. We
159  // can change this to other methods later on if necessary.
160  KALDI_ASSERT(sccs != NULL);
161  FindSccsTarjan(graph, sccs);
162 }
163 
164 void MakeSccGraph(const std::vector<std::vector<int32> > &graph,
165  const std::vector<std::vector<int32> > &sccs,
166  std::vector<std::vector<int32> > *scc_graph) {
167  KALDI_ASSERT(scc_graph != NULL);
168  scc_graph->clear();
169  scc_graph->resize(sccs.size());
170 
171  // Hash map from node to SCC index.
172  std::vector<int32> node_to_scc_index(graph.size());
173  for (int32 i = 0; i < sccs.size(); ++i) {
174  for (int32 j = 0; j < sccs[i].size(); ++j) {
175  KALDI_ASSERT(sccs[i][j] >= 0 && sccs[i][j] < graph.size());
176  node_to_scc_index[sccs[i][j]] = i;
177  }
178  }
179 
180  // Builds graph.
181  for (int32 i = 0; i < sccs.size(); ++i) {
182  for (int32 j = 0; j < sccs[i].size(); ++j) {
183  int32 node = sccs[i][j];
184  KALDI_ASSERT(node >= 0 && node < graph.size());
185  for (int32 k = 0; k < graph[node].size(); ++k) {
186  if (node_to_scc_index[graph[node][k]] != i) { // Exclucding self.
187  (*scc_graph)[i].push_back(node_to_scc_index[graph[node][k]]);
188  }
189  }
190  }
191  // If necessary, we can use a hash maps to avoid this sorting.
192  SortAndUniq(&((*scc_graph)[i]));
193  }
194 }
195 
197  const std::vector<std::vector<int32> > &graph,
198  std::vector<bool> *cycle_detector,
199  std::vector<bool> *is_visited,
200  std::vector<int32> *reversed_orders) {
201  KALDI_ASSERT(node >= 0 && node < graph.size());
202  KALDI_ASSERT(cycle_detector != NULL);
203  KALDI_ASSERT(is_visited != NULL);
204  KALDI_ASSERT(reversed_orders != NULL);
205  if ((*cycle_detector)[node]) {
206  KALDI_ERR << "Cycle detected when computing the topological sorting order";
207  }
208 
209  if (!(*is_visited)[node]) {
210  (*cycle_detector)[node] = true;
211  for (int32 i = 0; i < graph[node].size(); ++i) {
212  ComputeTopSortOrderRecursive(graph[node][i], graph,
213  cycle_detector, is_visited, reversed_orders);
214  }
215  (*cycle_detector)[node] = false;
216  (*is_visited)[node] = true;
217  // At this point we have added all the children to <reversed_orders>, so we
218  // can add the current now.
219  reversed_orders->push_back(node);
220  }
221 }
222 
223 void ComputeTopSortOrder(const std::vector<std::vector<int32> > &graph,
224  std::vector<int32> *node_to_order) {
225  // Internally we use DFS, but we only put the node to <node_to_order> when all
226  // its parents have been visited.
227  KALDI_ASSERT(node_to_order != NULL);
228  node_to_order->resize(graph.size());
229 
230  std::vector<bool> cycle_detector(graph.size(), false);
231  std::vector<bool> is_visited(graph.size(), false);
232 
233  std::vector<int32> reversed_orders;
234  for(int32 i = 0; i < graph.size(); ++i) {
235  if (!is_visited[i]) {
236  ComputeTopSortOrderRecursive(i, graph, &cycle_detector,
237  &is_visited, &reversed_orders);
238  }
239  }
240 
241  KALDI_ASSERT(node_to_order->size() == reversed_orders.size());
242  for (int32 i = 0; i < reversed_orders.size(); ++i) {
243  KALDI_ASSERT(reversed_orders[i] >= 0 && reversed_orders[i] < graph.size());
244  (*node_to_order)[reversed_orders[i]] = graph.size() - i - 1;
245  }
246 }
247 
248 std::string PrintGraphToString(const std::vector<std::vector<int32> > &graph) {
249  std::ostringstream os;
250  int32 num_nodes = graph.size();
251  for (int32 i = 0; i < num_nodes; i++) {
252  os << i << " -> (";
253  const std::vector<int32> &vec = graph[i];
254  int32 size = vec.size();
255  for (int32 j = 0; j < size; j++) {
256  os << vec[j];
257  if (j + 1 < size) os << ",";
258  }
259  os << ")";
260  if (i + 1 < num_nodes) os << "; ";
261  }
262  return os.str();
263 }
264 
266  std::vector<int32> *node_to_epoch) {
267  KALDI_ASSERT(node_to_epoch != NULL);
268 
269  std::vector<std::vector<int32> > graph;
270  NnetToDirectedGraph(nnet, &graph);
271  KALDI_VLOG(6) << "graph is: " << PrintGraphToString(graph);
272 
273  std::vector<std::vector<int32> > sccs;
274  FindSccs(graph, &sccs);
275 
276  std::vector<std::vector<int32> > scc_graph;
277  MakeSccGraph(graph, sccs, &scc_graph);
278  KALDI_VLOG(6) << "scc graph is: " << PrintGraphToString(scc_graph);
279 
280  std::vector<int32> scc_node_to_epoch;
281  ComputeTopSortOrder(scc_graph, &scc_node_to_epoch);
282  if (GetVerboseLevel() >= 6) {
283  std::ostringstream os;
284  for (int32 i = 0; i < scc_node_to_epoch.size(); i++)
285  os << scc_node_to_epoch[i] << ", ";
286  KALDI_VLOG(6) << "scc_node_to_epoch is: " << os.str();
287  }
288 
289  node_to_epoch->clear();
290  node_to_epoch->resize(graph.size());
291  for (int32 i = 0; i < sccs.size(); ++i) {
292  for (int32 j = 0; j < sccs[i].size(); ++j) {
293  int32 node = sccs[i][j];
294  KALDI_ASSERT(node >= 0 && node < graph.size());
295  (*node_to_epoch)[node] = scc_node_to_epoch[i];
296  }
297  }
298 }
299 
300 bool GraphHasCycles(const std::vector<std::vector<int32> > &graph) {
301  std::vector<std::vector<int32> > sccs;
302  FindSccs(graph, &sccs);
303  for (size_t i = 0; i < sccs.size(); i++) {
304  if (sccs[i].size() > 1)
305  return true;
306  }
307  // the next code checks for links from a state to itself.
308  int32 num_nodes = graph.size();
309  for (size_t i = 0; i < num_nodes; i++)
310  for (std::vector<int32>::const_iterator iter = graph[i].begin(),
311  end = graph[i].end(); iter != end; ++iter)
312  if (*iter == i) return true;
313  return false;
314 }
315 
316 } // namespace nnet3
317 } // namespace kaldi
void NnetToDirectedGraph(const Nnet &nnet, std::vector< std::vector< int32 > > *graph)
This function takes an nnet and turns it to a directed graph on nodes.
Definition: nnet-graph.cc:30
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
int32 NumNodes() const
Definition: nnet-nnet.h:126
void ComputeTopSortOrderRecursive(int32 node, const std::vector< std::vector< int32 > > &graph, std::vector< bool > *cycle_detector, std::vector< bool > *is_visited, std::vector< int32 > *reversed_orders)
Definition: nnet-graph.cc:196
void GetNodeDependencies(std::vector< int32 > *node_indexes) const
int32 GetVerboseLevel()
Get verbosity level, usually set via the command-line '--verbose=' switch.
Definition: kaldi-error.h:60
kaldi::int32 int32
void SortAndUniq(std::vector< T > *vec)
Sorts and uniq's (removes duplicates) from a vector.
Definition: stl-utils.h:39
void ComputeGraphTranspose(const std::vector< std::vector< int32 > > &graph, std::vector< std::vector< int32 > > *graph_transpose)
Outputs a graph in which the order of arcs is reversed.
Definition: nnet-graph.cc:63
bool GraphHasCycles(const std::vector< std::vector< int32 > > &graph)
This function returns 'true' if the graph represented in 'graph' contains cycles (including cycles where a single arc goes from a node to itself).
Definition: nnet-graph.cc:300
const NetworkNode & GetNode(int32 node) const
returns const reference to a particular numbered network node.
Definition: nnet-nnet.h:146
void ComputeTopSortOrder(const std::vector< std::vector< int32 > > &graph, std::vector< int32 > *node_to_order)
Given an acyclic graph (where each std::vector<int32> is a list of destination-nodes of arcs coming f...
Definition: nnet-graph.cc:223
struct rnnlm::@11::@12 n
void ComputeNnetComputationEpochs(const Nnet &nnet, std::vector< int32 > *node_to_epoch)
This function computes the order in which we need to compute each node in the graph, where each node-index n maps to an epoch-index t = 0, 1, ...
Definition: nnet-graph.cc:265
#define KALDI_ERR
Definition: kaldi-error.h:147
void FindSccs(const std::vector< std::vector< int32 > > &graph, std::vector< std::vector< int32 > > *sccs)
Given a directed graph (where each std::vector<int32> is a list of destination-nodes of arcs coming f...
Definition: nnet-graph.cc:156
NetworkNode is used to represent, three types of thing: either an input of the network (which pretty ...
Definition: nnet-nnet.h:81
This file contains a few functions that treat the neural net as a graph on nodes: e...
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void MakeSccGraph(const std::vector< std::vector< int32 > > &graph, const std::vector< std::vector< int32 > > &sccs, std::vector< std::vector< int32 > > *scc_graph)
Given a list of sccs of a graph (e.g.
Definition: nnet-graph.cc:164
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
void TarjanSccRecursive(int32 node, const std::vector< std::vector< int32 > > &graph, int32 *global_index, std::vector< TarjanNode > *tarjan_nodes, std::vector< int32 > *tarjan_stack, std::vector< std::vector< int32 > > *sccs)
Definition: nnet-graph.cc:85
union kaldi::nnet3::NetworkNode::@15 u
void FindSccsTarjan(const std::vector< std::vector< int32 > > &graph, std::vector< std::vector< int32 > > *sccs)
Definition: nnet-graph.cc:138
std::string PrintGraphToString(const std::vector< std::vector< int32 > > &graph)
Prints a graph to a string in a pretty way for human readability, e.g.
Definition: nnet-graph.cc:248