nnet-computation-graph.h
Go to the documentation of this file.
1 // nnet3/nnet-computation-graph.h
2 
3 // Copyright 2015 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #ifndef KALDI_NNET3_NNET_COMPUTATION_GRAPH_H_
21 #define KALDI_NNET3_NNET_COMPUTATION_GRAPH_H_
22 
24 #include "nnet3/nnet-nnet.h"
25 #include "nnet3/nnet-computation.h"
26 
27 #include <iostream>
28 #include <deque>
29 
30 namespace kaldi {
31 namespace nnet3 {
32 
44 
46  std::vector<Cindex> cindexes;
47 
55  std::vector<bool> is_input;
56 
63  std::vector<std::vector<int32> > dependencies;
64 
80  std::vector<int32> segment_ends;
81 
86  int32 GetCindexId(const Cindex &cindex, bool is_input, bool *is_new);
87 
90  int32 GetCindexId(const Cindex &cindex) const;
91 
93  // keeping only for which keep[c - start_cindex_id] is
96  void Renumber(int32 start_cindex_id,
97  const std::vector<bool> &keep);
98 
99 
106  void Print(std::ostream &os, const std::vector<std::string> &node_names);
107 
108  private:
111  unordered_map<Cindex, int32, CindexHasher> cindex_to_cindex_id_;
112 };
113 
114 
118  public:
119  ComputationGraphBuilder(const Nnet &nnet,
120  ComputationGraph *graph);
121 
122  // Does the initial computation (populating the graph and computing whether
123  // each required cindex_id is computable), without the pruning. In the normal
124  // case you call this just once with one 'request', but in the 'online' case
125  // you call Compute() [then maybe check AllOutputsAreComputable()] then
126  // Prune() multiple times, with a sequence of different requests for
127  // increasing time values.
128  // Note: it sets the class member request_ to the address of 'request', so
129  // you should not let 'request' go out of scope while this class might
130  // still use it (e.g. until you call Compute() with a different
131  void Compute(const ComputationRequest &request);
132 
133  // Returns true if all requested outputs are computable. To be called after
134  // Compute() but before Prune(().
135  bool AllOutputsAreComputable() const;
136 
137  // Prints logging info to explain why all outputs are not computable.
138  // To be called only if AllOutputsAreComputable() returned false.
139  void ExplainWhyAllOutputsNotComputable() const;
140 
141  // This function outputs to "computable" information about whether each
142  // requested element of each output was computable. "computable" will have
143  // the same size as request_->outputs, and each element will have the same
144  // size as request_->outputs[i].indexes.size(). May only be called after
145  // Compute() but before Prune(). If you have already called Prune(), you can
146  // just assume everything was computable, or else Prune() would have crashed.
147  void GetComputableInfo(std::vector<std::vector<bool> > *computable) const;
148 
149  // to be called after Compute(), this prunes away unused cindex_ids.
150  // If not all the outputs are computable, this will die;
151  // you can check the return status of AllOutputsAreComputable() first if
152  // you want to avoid this.
153  void Prune();
154 
155  // This enum says for each cindex_id, whether we can compute it from the given
156  // inputs or not. Note that there may be situations where before adding
157  // dependencies of a particular cindex_id we realize that we won't be able to
158  // use this cindex_id (i.e. it may be computable but it's not used) because
159  // its usable_count is zero, and in those cases we change the status to
160  // kWillNotCompute even though the cindex-id may be computable. For most
161  // purposes this status is treated the same as kNotComputable.
163  kUnknown = 0,
164  kComputable = 1,
165  kNotComputable = 2,
166  kWillNotCompute = 3
167  };
168 
169  struct CindexInfo {
170  ComputableInfo computable; // kUnknown, kComputable, kNotComputable
171  int32 usable_count; // usable_count_[i] for a cindex_id i is defined as 1 if i is a requested
172  // output, and otherwise as the number of other cindex_ids j such that
173  // computable_info_[j] is not kNotComputable AND usable_count_[j] > 0 AND i is
174  // a member of graph->dependencies[j]. A cindex_id is termed "usable"
175  // (meaning it could potentially participate in the computation of the output)
176  // if its usable_count_ is > 0. This quantity is designed to be easy to keep
177  // updated as we add cindex_ids.
178 
179  // True if in current_queue_ or next_queue_.
180  bool queued;
181 
182  // True if we have created the cindexes that this cindex depends on.
184 
185  CindexInfo(const CindexInfo &other) = default;
186  CindexInfo(): computable(kUnknown),
187  usable_count(0),
188  queued(false),
189  dependencies_computed(false) { }
190  };
191 
192  private:
193  // This function, called from ExplainWhyNotComputable(), prints to "os"
194  // a human-readable form of a given cindex_id, that looks like
195  // some_network_node(n, t, x), e.g. "final_logsoftmax(0, -4, 0)".
196  void PrintCindexId(std::ostream &os, int32 cindex_id) const;
197 
198  // This function, typically to be called just before dying, prints logging
199  // information to explain why the given cindex_id is not computable.
200  void ExplainWhyNotComputable(int32 cindex_id) const;
201 
202  // called at the start of Compute(), this populates the graph (and member
203  // variables) for all the inputs specified in the computation request.
204  void AddInputs();
205 
206  // called at the start of Compute(), this populates the graph (and member
207  // variables, including current_queue_) with all the outputs specified in the
208  // computation request.
209  void AddOutputs();
210 
211  // this does one iteration of building the graph, and increases
212  // current_distance_ by one, i.e. it searches at one more remove from
213  // the output.
214  void BuildGraphOneIter();
215 
216  // (called from BuildGraphOneIter()); make sure the computable_info for
217  // cindex_id is up to date. Has side effects: may update usable_count
218  // values and add things to next_queue_.
219  void UpdateComputableInfo(int32 cindex_id);
220 
221  // (called from BuildGraphOneIter()), this function sets the cindex_id to
222  // status kWillNotCompute and places members of depend_on_this_ into the
223  // computable queue if needed.
224  void SetAsWillNotCompute(int32 cindex_id);
225 
226  // compute and return the ComputableInfo for this cindex_id (kUnknown,
227  // kComputable or kNotComputable).
228  ComputableInfo ComputeComputableInfo(int32 cindex_id) const;
229 
230  // To be called when this cindex_id has just been newly added to graph_, this
231  // function adds a couple default variables associated with it, to *this.
232  inline void AddCindexId(int32 cindex_id);
233 
234  // Add cindex_ids that this cindex_id depends on.
235  void AddDependencies(int32 cindex_id);
236 
237  // increment the "usable" value of this cindex_id.
238  void IncrementUsableCount(int32 cindex_id);
239 
240  // decrement the "usable" value of this cindex_id.
241  void DecrementUsableCount(int32 cindex_id);
242 
243  // This function, called from Prune(), modifies the members of
244  // graph_->dependencies-- it removes those cindexes that are not used in the
245  // computation for the current cindex_id. This will only do something
246  // interesting in cases where there are optional dependencies.
247  // It also clears the dependencies of those cindexes that are not computable.
248  void PruneDependencies(int32 cindex_id);
249 
250  // This function, called from Prune(), computes an array "required", with an
251  // element for each cindex_id that says whether it is required to compute the
252  // requested outputs. This is similar in function to the "usable_count_"
253  // array, but it's more exact because it's computed after we have done
254  // PruneDependencies() to remove unused dependencies, so it will only say
255  // something is required if it is really accessed in the computation.
256  // We'll later use this to remove unnecessary cindexes.
257  // 'start_cindex_id' is the cindex_id from which the 'required' array is
258  // to start (normally zero, but may be nonzero in multi-segment computations);
259  // so 'required' is indexed by cindex_id - start_cindex_id.
260  void ComputeRequiredArray(int32 start_cindex_id,
261  std::vector<bool> *required) const;
262 
263  // this function, to be called from Compute(), does some sanity checks to
264  // verify that the internal state is consistent. It only does this for the
265  // current 'segment' of the computation, starting from 'start_cindex_id' (this
266  // will be 0 in normal, single-segment computations).
267  void Check(int32 start_cindex_id) const;
268 
269  const Nnet &nnet_;
272 
273  // this is the transpose of graph_->dependencies; it tells us
274  // for each cindex_id, which other cindex_ids depend on it.
275  std::vector<std::vector<int32> > depend_on_this_;
276 
277 
278  // this vector is indexed by cindex_id
279  std::vector<CindexInfo> cindex_info_;
280 
281  // current_distance_ >= 0 is the distance to the output, of the cindex_ids in
282  // current_queue_.
284  // the cindex_ids in current_queue_ are at no more than distance
285  // "current_distance" to the output
286  std::vector<int32> current_queue_;
287  // the cindex_ids in next_queue_ are at no more than distance current_distance
288  // + 1 to the output
289  std::vector<int32> next_queue_;
290 };
291 
// NOTE(review): this declaration is truncated in this listing. The line that
// should follow 'std::ostream &os,' (its second parameter and the closing ');')
// is missing, so the declaration is incomplete as shown. Restore it from the
// original nnet3/nnet-computation-graph.h.
293 std::ostream& operator << (std::ostream &os,
295 
296 
// An abstract representation of a set of Cindexes.
297 class CindexSet {
298  public:
     // Set-membership test; presumably returns true if 'cindex' is in the
     // set. TODO confirm.
300  bool operator () (const Cindex &cindex) const;
301 
     // Constructor from a graph alone (no per-cindex_id info).
     // NOTE(review): presumably the set is then every cindex present in
     // 'graph'. Confirm.
304  CindexSet(const ComputationGraph &graph);
305 
     // Constructor that also takes the per-cindex_id info vector from
     // ComputationGraphBuilder.
     // NOTE(review): 'treat_unknown_as_computable' presumably makes
     // cindex_ids whose status is kUnknown count as members. Confirm.
311  CindexSet(const ComputationGraph &graph,
312  const std::vector<ComputationGraphBuilder::CindexInfo> &info,
313  bool treat_unknown_as_computable);
314  private:
     // NOTE(review): some private member declarations appear to be missing
     // from this listing. The generated-docs tooltips list a
     // 'const ComputationGraph & graph_' member, and the
     // 'treat_unknown_as_computable' argument presumably needs storage too.
     // Restore them from the original header.
     // info_ is a pointer rather than a reference, presumably so it can be
     // NULL when the single-argument constructor is used. TODO confirm.
316  const std::vector<ComputationGraphBuilder::CindexInfo> *info_;
318 };
319 
320 
// An abstract representation of a set of Indexes.
322 class IndexSet {
323  public:
     // Set-membership test; presumably returns true if 'index' is in the
     // set. TODO confirm.
325  bool operator () (const Index &index) const;
326 
     // NOTE(review): presumably the set is the Indexes i for which the Cindex
     // (node_id, i) is in 'graph' and has acceptable status in 'info'. With
     // 'treat_unknown_as_computable', presumably kUnknown counts as
     // computable. Confirm both.
332  IndexSet(const ComputationGraph &graph,
333  const std::vector<ComputationGraphBuilder::CindexInfo> &info,
334  int32 node_id,
335  bool treat_unknown_as_computable);
336  private:
     // NOTE(review): some private member declarations appear to be missing
     // from this listing. The generated-docs tooltips list a
     // 'const ComputationGraph & graph_' member, and the node_id and
     // treat_unknown_as_computable arguments presumably need storage too.
     // Restore them from the original header.
     // Held by reference: 'info' passed to the constructor must outlive this
     // object.
338  const std::vector<ComputationGraphBuilder::CindexInfo> &info_;
341 };
342 
343 
344 
345 
377  const Nnet &nnet,
378  const ComputationGraph &computation_graph,
379  std::vector<std::vector<std::vector<int32> > > *phases_per_segment);
380 
381 
416  public:
447  ComputationStepsComputer(const Nnet &nnet,
448  ComputationGraph *graph,
449  std::vector<std::vector<int32> > *steps,
450  std::vector<std::pair<int32, int32> > *locations);
451 
454  void ComputeForSegment(const ComputationRequest &request,
455  const std::vector<std::vector<int32> > &phases);
456 
459  void Check() const;
460  private:
461 
462  // Adds step(s) for one "sub-phase". A sub-phase is the set of cindex_ids from
463  // one phase that have the same node index. Note: for nodes that are
464  // component-input descriptors, we don't actually create the step here, we
465  // create it just before creating the step for its component, and we recreate
466  // the list of cindexes from those from the component. The reason is that
467  // there are situations where doing it directly from the raw_step would not do
468  // the right thing (especially with non-simple components, it's possible that
469  // the cindexes component-input descriptors could be used twice by two
470  // different components)..
471  void ProcessSubPhase(const ComputationRequest &request,
472  const std::vector<Cindex> &sub_phase);
473 
474  // Called from ProcessSubPhase- for the case where it's a DimRangeNode.
475  void ProcessDimRangeSubPhase(const std::vector<Cindex> &sub_phase);
476 
477  // Called from ProcessSubPhase- for the case where it's an input or output node.
478  void ProcessInputOrOutputStep(const ComputationRequest &request,
479  bool is_output,
480  const std::vector<Cindex> &sub_phase);
481 
482  // Called from ProcessSubPhase- for the case where it's a component node.
483  void ProcessComponentStep(const std::vector<Cindex> &step);
484 
485 
486  // Splits a phase up into multiple "sub-phases", which are just the cindexes
487  // from a phase that are from a single node, sorted. At this point we
488  // represent them as Cindexes, not cindex_ids. For efficiency and because it
489  // would be discarded anyway, it discards any raw steps that correspond to
490  // component-input descriptors because these are not processed inside
491  // ProcessSubPhase().
492  void SplitIntoSubPhases(const std::vector<int32> &phase,
493  std::vector<std::vector<Cindex> > *sub_phase) const;
494 
495  // This low-level function used by functions like ProcessComponentStep,
496  // ProcessInputStep and so on, adds one step to 'steps_' (converting from
497  // Cindex to cindex_ids), and updates 'locations' appropriately. It returns
498  // the step index that we just added (== size of steps_ at entry).
499  // If you specify add_if_absent = true, it will add any Cindexes that were
500  // not already present, to the graph. [this option is only to be used
501  // in processing dim-range nodes.
502  int32 AddStep(const std::vector<Cindex> &cindexes,
503  bool add_if_absent = false);
504 
505  // This is an alternative interface to AddStep() that takes a list of
506  // cindex_ids instead of cindexes (it's destructive of that list).
507  int32 AddStep(std::vector<int32> *cindex_ids);
508 
509 
510  // This utility function uses graph_ to convert a vector of cindex_ids into
511  // Cindexes.
512  void ConvertToCindexes(const std::vector<int32> &cindex_ids,
513  std::vector<Cindex> *cindexes) const;
514 
515  // Converts a vector of Cindexes to a vector of Indexes, by
516  // stripping out the node index.
517  static void ConvertToIndexes(const std::vector<Cindex> &cindexes,
518  std::vector<Index> *indexes);
519 
520  // Converts a vector of Indexes to Cindexes, using a supplied
521  // node index.
522  static void ConvertToCindexes(const std::vector<Index> &indexes,
523  int32 node_index,
524  std::vector<Cindex> *cindexes);
525 
526 
527  // This utility function uses graph_ to convert a vector of cindex_ids into
528  // Cindexes. It will crash if the cindexes were not present in the graph.
529  void ConvertToCindexIds(const std::vector<Cindex> &cindexes,
530  std::vector<int32> *cindex_ids) const;
531 
532  // This utility function uses the 'locations_' array to convert the cindex_ids
533  // in 'cindex_ids' into an array (of the same length) of locations, i.e. of
534  // pairs (step, index-into-step), so that if cindex_ids[i] = c, then
535  // (*locations)[i] will be set to (*locations_)[c]. It will die if
536  // one of the locations was not defined, i.e. was the pair (-1, -1).
537  void ConvertToLocations(
538  const std::vector<int32> &cindex_ids,
539  std::vector<std::pair<int32, int32> > *locations) const;
540 
541 
542  const Nnet &nnet_;
545  std::vector<std::vector<int32> > *steps_;
550  std::vector<std::pair<int32, int32> > *locations_;
551 
552 
559  std::unordered_set<std::pair<int32, int32>, PairHasher<int32> > dim_range_nodes_;
560 };
561 
562 
563 
564 } // namespace nnet3
565 } // namespace kaldi
566 
567 
568 #endif
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
bool ConvertToIndexes(const std::vector< std::pair< int32, int32 > > &location_vector, int32 *first_value, std::vector< int32 > *second_values)
If it is the case for some i >= 0 that all the .first elements of "location_vector" are either i or -...
An abstract representation of a set of Indexes.
std::vector< std::pair< int32, int32 > > * locations_
locations_ is a map from cindex_id to the pair of indexes into steps_ where that cindex_id resides...
const std::vector< ComputationGraphBuilder::CindexInfo > * info_
kaldi::int32 int32
std::ostream & operator<<(std::ostream &ostream, const Index &index)
Definition: nnet-common.cc:424
struct Index is intended to represent the various indexes by which we number the rows of the matrices...
Definition: nnet-common.h:44
The two main classes defined in this header are struct ComputationRequest, which basically defines a ...
const ComputationGraph & graph_
std::pair< int32, Index > Cindex
Definition: nnet-common.h:115
std::vector< Cindex > cindexes
The mapping of cindex_id to Cindex.
std::vector< std::vector< int32 > > dependencies
dependencies[cindex_id] gives you the list of other cindex_ids that this particular cindex_id directl...
int32 GetCindexId(const Cindex &cindex, bool is_input, bool *is_new)
Maps a Cindex to an integer cindex_id.
unordered_map< Cindex, int32, CindexHasher > cindex_to_cindex_id_
Maps each Cindex to an integer cindex_id: reverse mapping of "cindexes".
void ComputeComputationPhases(const Nnet &nnet, const ComputationGraph &graph, std::vector< std::vector< std::vector< int32 > > > *phases_per_segment)
This function divides a computation into 'phases', where a 'phase' is a collection of cindexes which ...
std::vector< bool > is_input
For each Cindex this tells us whether it was provided as an input to the network. ...
std::vector< std::vector< int32 > > depend_on_this_
std::vector< int32 > segment_ends
This variable is only of particular interest in a 'multi-segment' computation, which is used while cr...
std::vector< std::vector< int32 > > * steps_
steps_ is a pointer to an output that's passed in to the constructor.
void Renumber(int32 start_cindex_id, const std::vector< bool > &keep)
This function renumbers the cindex-ids (but only those with index c >= start_cindex_id).
std::unordered_set< std::pair< int32, int32 >, PairHasher< int32 > > dim_range_nodes_
dim_range_nodes_ is used when allocating steps for nodes of type kDimRangeNode.
const std::vector< ComputationGraphBuilder::CindexInfo > & info_
const ComputationGraph & graph_
This class arranges the cindex_ids of the computation into a sequence of lists called "steps"...
void Print(std::ostream &os, const std::vector< std::string > &node_names)
This function, useful for debugging/visualization purposes, prints out a summary of the computation g...
The first step in compilation is to turn the ComputationSpecification into a ComputationGraph, where for each Cindex we have a list of other Cindexes that it depends on.
An abstract representation of a set of Cindexes.
A hashing function-object for pairs of ints.
Definition: stl-utils.h:235