nnet-nnet.h
Go to the documentation of this file.
1 // nnet3/nnet-nnet.h
2 
3 // Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey)
4 // 2016 Daniel Galvez
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #ifndef KALDI_NNET3_NNET_NNET_H_
21 #define KALDI_NNET3_NNET_NNET_H_
22 
23 #include "base/kaldi-common.h"
24 #include "util/kaldi-io.h"
25 #include "matrix/matrix-lib.h"
26 #include "nnet3/nnet-common.h"
28 #include "nnet3/nnet-descriptor.h"
29 
30 #include <iostream>
31 #include <sstream>
32 #include <vector>
33 #include <map>
34 
35 namespace kaldi {
36 namespace nnet3 {
37 
38 
39 
53 
54 
56 
57 
58 
// NetworkNode describes a single node of the network graph held by Nnet.  Its
// constructor takes a NodeType, and the comments below refer to the node types
// kInput, kDescriptor, kComponent and kDimRange.
// NOTE(review): this is a rendered listing; the member declarations that the
// comments below describe (listing lines 82, 84, 87, 90, 97, 101 and 103) are
// not shown here.  The constructor initializes node_type, dim, dim_offset and
// u.component_index, and the page index reports
// "ObjectiveType objective_type" at line 97 -- confirm the remaining member
// types against the original header.
81 struct NetworkNode {
83  // "descriptor" is relevant only for nodes of type kDescriptor.
85  union {
86  // For kComponent, the index into Nnet::components_
88  // for kDimRange, the node-index of the input node, which must be of
89  // type kComponent or kInput.
91 
92  // for nodes of type kDescriptor that are output nodes (i.e. not followed by
93  // a node of type kComponent), the objective function associated with the
94  // output. The core parts of the nnet code just ignore it; it is required only
95  // for the information of the calling code, which is perfectly free to
96  // ignore it. View it as a kind of annotation.
98  } u;
99  // for kInput, the dimension of the input feature. For kDimRange, the dimension
100  // of the output (i.e. the length of the range)
102  // for kDimRange, the dimension of the offset into the input component's feature.
104 
105  int32 Dim(const Nnet &nnet) const; // Dimension that this node outputs.
106 
// Constructs a node of type "nt" (kNone by default); dim, dim_offset and
// u.component_index all start at -1.
107  NetworkNode(NodeType nt = kNone):
108  node_type(nt), dim(-1), dim_offset(-1) { u.component_index = -1; }
109  NetworkNode(const NetworkNode &other); // copy constructor.
110  // use default assignment operator
111 };
112 
113 
114 
115 class Nnet {
116  public:
117  // This function can be used either to initialize a new Nnet from a config
118  // file, or to add to an existing Nnet, possibly replacing certain parts of
119  // it. It will die with error if something went wrong.
120  // Also see the function ReadEditConfig() in nnet-utils.h (it's made a
121  // non-member because it doesn't need special access).
122  void ReadConfig(std::istream &config_file);
123 
124  int32 NumComponents() const { return components_.size(); }
125 
126  int32 NumNodes() const { return nodes_.size(); }
127 
129  Component *GetComponent(int32 c);
130 
133  const Component *GetComponent(int32 c) const;
134 
138  void SetComponent(int32 c, Component *component);
139 
143  int32 AddComponent(const std::string &name, Component *component);
144 
146  const NetworkNode &GetNode(int32 node) const {
147  KALDI_ASSERT(node >= 0 && node < nodes_.size());
148  return nodes_[node];
149  }
150 
153  KALDI_ASSERT(node >= 0 && node < nodes_.size());
154  return nodes_[node];
155  }
156 
159  bool IsComponentNode(int32 node) const;
160 
163  bool IsDimRangeNode(int32 node) const;
164 
167  bool IsInputNode(int32 node) const;
168 
172  bool IsDescriptorNode(int32 node) const;
173 
176  bool IsOutputNode(int32 node) const;
177 
180  bool IsComponentInputNode(int32 node) const;
181 
183  const std::vector<std::string> &GetNodeNames() const;
184 
186  const std::string &GetNodeName(int32 node_index) const;
187 
191  void SetNodeName(int32 node_index, const std::string &new_name);
192 
194  const std::vector<std::string> &GetComponentNames() const;
195 
197  const std::string &GetComponentName(int32 component_index) const;
198 
200  int32 GetNodeIndex(const std::string &node_name) const;
201 
203  int32 GetComponentIndex(const std::string &node_name) const;
204 
205  // This convenience function returns the dimension of the input with name
206  // "input_name" (e.g. input_name="input" or "ivector"), or -1 if there is no
207  // such input.
208  int32 InputDim(const std::string &input_name) const;
209 
210  // This convenience function returns the dimension of the output with
211  // name "input_name" (e.g. output_name="input"), or -1 if there is
212  // no such input.
213  int32 OutputDim(const std::string &output_name) const;
214 
215  void Read(std::istream &istream, bool binary);
216 
217  void Write(std::ostream &ostream, bool binary) const;
218 
223  void Check(bool warn_for_orphans = true) const;
224 
229  std::string Info() const;
230 
235  int32 Modulus() const;
236 
237  ~Nnet() { Destroy(); }
238 
239  // Default constructor
240  Nnet() { }
241 
242 
243  // Copy constructor
244  Nnet(const Nnet &nnet);
245 
246  Nnet *Copy() const { return new Nnet(*this); }
247 
248  void Swap(Nnet *other);
249 
250  // Assignment operator
251  Nnet& operator =(const Nnet &nnet);
252 
253  // Removes nodes that are never needed to compute any output.
254  void RemoveOrphanNodes(bool remove_orphan_inputs = false);
255 
256  // Removes components that are not used by any node.
257  void RemoveOrphanComponents();
258 
259  // Removes some nodes. This is not to be called without a lot of thought,
260  // as it could ruin the graph structure if done carelessly.
261  void RemoveSomeNodes(const std::vector<int32> &nodes_to_remove);
262 
263  void ResetGenerators(); // resets random-number generators for all
264  // random components. You must call srand() prior to this call, for this to
265  // be effective.
266 
267 
268  // This function outputs to "config_lines" the lines of a config file. If you
269  // provide include_dim=false, this will enable you to reconstruct the nodes in
270  // the network (but not the components, which need to be written separately).
271  // If you provide include_dim=true, it also adds extra information about
272  // node dimensions which is useful for a human reader but won't be
273  // accepted as the config-file format.
274  void GetConfigLines(bool include_dim,
275  std::vector<std::string> *config_lines) const;
276 
277  private:
278 
279  void Destroy();
280 
281  // This function returns as a string the contents of a line of a config-file
282  // corresponding to the node indexed "node_index", which must not be of type
283  // kComponentInput. If include_dim=false, it appears in the same format as it
284  // would appear in a line of a config-file; if include_dim=true, we also
285  // include dimension information that would not be provided in a config file.
286  std::string GetAsConfigLine(int32 node_index, bool include_dim) const;
287 
288 
289  // This function is used when reading config files; it exists in order to
290  // handle replacement of existing nodes. The two input vectors have the same
291  // size. Its job is to remove redundant lines that do not have "component" as
292  // first_token, and where two lines have a configuration value name=xxx in the
293  // config with the same name. In this case it removes the first of the two,
294  // but that first one must have index less than num_lines_initial, else it is
295  // an error.
296  // This function also checks that all lines have a config name=xxx, that
297  // IsValidName(xxx) is true, and that there are no two lines with "component"
298  // as the first token and with the same config name=xxx. Note: here, "name"
299  // means literally "name", but "xxx" stands in for the actual name,
300  // e.g. "my-funky-component."
301  static void RemoveRedundantConfigLines(int32 num_lines_initial,
302  std::vector<ConfigLine> *config_lines);
303 
304  void ProcessComponentConfigLine(int32 initial_num_components,
305  ConfigLine *config);
306  void ProcessComponentNodeConfigLine(int32 pass,
307  ConfigLine *config);
308  void ProcessInputNodeConfigLine(ConfigLine *config);
309  void ProcessOutputNodeConfigLine(int32 pass,
310  ConfigLine *config);
311  void ProcessDimRangeNodeConfigLine(int32 pass,
312  ConfigLine *config);
313 
314  // This function output to "modified_node_names" a modified copy of
315  // node_names_, in which all nodes which are not of type kComponent, kInput or
316  // kDimRange are replaced with the string "***". This is useful when parsing
317  // Descriptors, to avoid inadvertently accepting nodes of invalid types where
318  // they are not allowed.
319  void GetSomeNodeNames(std::vector<std::string> *modified_node_names) const;
320 
321 
322  // the names of the components of the network. Note, these may be distinct
323  // from the network node names below (and live in a different namespace); the
324  // same component may be used in multiple network nodes, to define parameter
325  // sharing.
326  std::vector<std::string> component_names_;
327 
328  // the components of the nnet, in arbitrary order. The network topology is
329  // defined separately, below; a given Component may appear more than once in
330  // the network if necessary for parameter tying.
331  std::vector<Component*> components_;
332 
333  // names of network nodes, i.e. inputs, components and outputs, used only in
334  // reading and writing code. Indexed by network-node index. Note,
335  // components' names are always listed twice, once as foo-input and once as
336  // foo, because the input to a component always gets its own NetworkNode index.
337  std::vector<std::string> node_names_;
338 
339  // the network nodes of the network.
340  std::vector<NetworkNode> nodes_;
341 
342 };
343 
344 
345 } // namespace nnet3
346 } // namespace kaldi
347 
348 #endif
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
int32 NumNodes() const
Definition: nnet-nnet.h:126
NetworkNode & GetNode(int32 node)
Non-const accessor for the node... use with extreme caution.
Definition: nnet-nnet.h:152
Abstract base-class for neural-net components.
kaldi::int32 int32
std::vector< std::string > component_names_
Definition: nnet-nnet.h:326
std::vector< Component * > components_
Definition: nnet-nnet.h:331
ObjectiveType objective_type
Definition: nnet-nnet.h:97
std::vector< std::string > node_names_
Definition: nnet-nnet.h:337
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340
void ResetGenerators(Nnet *nnet)
This function calls 'ResetGenerator()' on all components in 'nnet' that inherit from class RandomComp...
Definition: nnet-utils.cc:582
const NetworkNode & GetNode(int32 node) const
returns const reference to a particular numbered network node.
Definition: nnet-nnet.h:146
int32 Dim(const Nnet &nnet) const
Definition: nnet-nnet.cc:33
Nnet * Copy() const
Definition: nnet-nnet.h:246
NetworkNode is used to represent three types of thing: either an input of the network (which pretty ...
Definition: nnet-nnet.h:81
ObjectiveType
This enum is for a kind of annotation we associate with output nodes of the network; it's for the con...
Definition: nnet-nnet.h:52
int32 NumComponents() const
Definition: nnet-nnet.h:124
This class is responsible for parsing input like hi-there xx=yyy a=b c empty= f-oo=Append(bar, sss) ba_z=123 bing='a b c' baz="a b c d='a b' e" and giving you access to the fields, in this case.
Definition: text-utils.h:205
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
union kaldi::nnet3::NetworkNode::@15 u
This file contains class definitions for classes ForwardingDescriptor, SumDescriptor and Descriptor...
NetworkNode(NodeType nt=kNone)
Definition: nnet-nnet.h:107