nnet-nnet.cc
Go to the documentation of this file.
1 // nnet3/nnet-nnet.cc
2 
3 // Copyright 2015 Johns Hopkins University (author: Daniel Povey)
4 // 2016 Daniel Galvez
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include <iterator>
21 #include <sstream>
22 #include "nnet3/nnet-nnet.h"
23 #include "nnet3/nnet-parse.h"
24 #include "nnet3/nnet-utils.h"
26 #include "nnet3/am-nnet-simple.h"
27 #include "hmm/transition-model.h"
28 
29 namespace kaldi {
30 namespace nnet3 {
31 
32 // returns dimension that this node outputs.
33 int32 NetworkNode::Dim(const Nnet &nnet) const {
34  int32 ans;
35  switch (node_type) {
36  case kInput: case kDimRange:
37  ans = dim;
38  break;
39  case kDescriptor:
40  ans = descriptor.Dim(nnet);
41  break;
42  case kComponent:
43  ans = nnet.GetComponent(u.component_index)->OutputDim();
44  break;
45  default:
46  ans = 0; // suppress compiler warning
47  KALDI_ERR << "Invalid node type.";
48  }
49  KALDI_ASSERT(ans > 0);
50  return ans;
51 }
52 
53 void Nnet::SetNodeName(int32 node_index, const std::string &new_name) {
54  if (!(static_cast<size_t>(node_index) < nodes_.size()))
55  KALDI_ERR << "Invalid node index";
56  if (GetNodeIndex(new_name) != -1)
57  KALDI_ERR << "You cannot rename a node to create a duplicate node name";
58  if (!IsValidName(new_name))
59  KALDI_ERR << "Node name " << new_name << " is not allowed.";
60  node_names_[node_index] = new_name;
61 }
62 
63 const std::vector<std::string> &Nnet::GetNodeNames() const {
64  return node_names_;
65 }
66 
67 const std::vector<std::string> &Nnet::GetComponentNames() const {
68  return component_names_;
69 }
70 
71 std::string Nnet::GetAsConfigLine(int32 node_index, bool include_dim) const {
72  std::ostringstream ans;
73  KALDI_ASSERT(node_index < nodes_.size() &&
74  nodes_.size() == node_names_.size());
75  const NetworkNode &node = nodes_[node_index];
76  const std::string &name = node_names_[node_index];
77  switch (node.node_type) {
78  case kInput:
79  ans << "input-node name=" << name << " dim=" << node.dim;
80  break;
81  case kDescriptor:
82  // assert that it's an output-descriptor, not one describing the input to
83  // a component-node.
84  KALDI_ASSERT(IsOutputNode(node_index));
85  ans << "output-node name=" << name << " input=";
86  node.descriptor.WriteConfig(ans, node_names_);
87  if (include_dim)
88  ans << " dim=" << node.Dim(*this);
89  ans << " objective=" << (node.u.objective_type == kLinear ? "linear" :
90  "quadratic");
91  break;
92  case kComponent:
93  ans << "component-node name=" << name << " component="
94  << component_names_[node.u.component_index] << " input=";
95  KALDI_ASSERT(nodes_[node_index-1].node_type == kDescriptor);
96  nodes_[node_index-1].descriptor.WriteConfig(ans, node_names_);
97  if (include_dim)
98  ans << " input-dim=" << nodes_[node_index-1].Dim(*this)
99  << " output-dim=" << node.Dim(*this);
100  break;
101  case kDimRange:
102  ans << "dim-range-node name=" << name << " input-node="
103  << node_names_[node.u.node_index] << " dim-offset="
104  << node.dim_offset << " dim=" << node.dim;
105  break;
106  default:
107  KALDI_ERR << "Unknown node type.";
108  }
109  return ans.str();
110 }
111 
112 bool Nnet::IsOutputNode(int32 node) const {
113  int32 size = nodes_.size();
114  KALDI_ASSERT(node >= 0 && node < size);
115  return (nodes_[node].node_type == kDescriptor &&
116  (node + 1 == size ||
117  nodes_[node + 1].node_type != kComponent));
118 }
119 
120 bool Nnet::IsInputNode(int32 node) const {
121  int32 size = nodes_.size();
122  KALDI_ASSERT(node >= 0 && node < size);
123  return (nodes_[node].node_type == kInput);
124 }
125 
126 bool Nnet::IsDescriptorNode(int32 node) const {
127  int32 size = nodes_.size();
128  KALDI_ASSERT(node >= 0 && node < size);
129  return (nodes_[node].node_type == kDescriptor);
130 }
131 
132 bool Nnet::IsComponentNode(int32 node) const {
133  int32 size = nodes_.size();
134  KALDI_ASSERT(node >= 0 && node < size);
135  return (nodes_[node].node_type == kComponent);
136 }
137 
138 bool Nnet::IsDimRangeNode(int32 node) const {
139  int32 size = nodes_.size();
140  KALDI_ASSERT(node >= 0 && node < size);
141  return (nodes_[node].node_type == kDimRange);
142 }
143 
144 
146  KALDI_ASSERT(static_cast<size_t>(c) < components_.size());
147  return components_[c];
148 }
149 
151  KALDI_ASSERT(static_cast<size_t>(c) < components_.size());
152  return components_[c];
153 }
154 
155 void Nnet::SetComponent(int32 c, Component *component) {
156  KALDI_ASSERT(static_cast<size_t>(c) < components_.size());
157  delete components_[c];
158  components_[c] = component;
159 }
160 
161 int32 Nnet::AddComponent(const std::string &name,
162  Component *component) {
163  int32 ans = components_.size();
164  KALDI_ASSERT(IsValidName(name) && component != NULL);
165  components_.push_back(component);
166  component_names_.push_back(name);
167  return ans;
168 }
169 
173  int32 size = nodes_.size();
174  KALDI_ASSERT(node >= 0 && node < size);
175  return (node + 1 < size &&
176  nodes_[node].node_type == kDescriptor &&
177  nodes_[node+1].node_type == kComponent);
178 }
179 
180 void Nnet::GetConfigLines(bool include_dim,
181  std::vector<std::string> *config_lines) const {
182  config_lines->clear();
183  for (int32 n = 0; n < NumNodes(); n++)
184  if (!IsComponentInputNode(n))
185  config_lines->push_back(GetAsConfigLine(n, include_dim));
186 
187 }
188 
189 void Nnet::ReadConfig(std::istream &config_is) {
190 
191  std::vector<std::string> lines;
192  // Write into "lines" a config file corresponding to whatever
193  // nodes we currently have. Because the numbering of nodes may
194  // change, it's most convenient to convert to the text representation
195  // and combine the existing and new config lines in that representation.
196  const bool include_dim = false;
197  GetConfigLines(include_dim, &lines);
198 
199  // we'll later regenerate what we need from nodes_ and node_name_ from the
200  // string representation.
201  nodes_.clear();
202  node_names_.clear();
203 
204  int32 num_lines_initial = lines.size();
205 
206  ReadConfigLines(config_is, &lines);
207  // now "lines" will have comments removed and empty lines stripped out
208 
209  std::vector<ConfigLine> config_lines(lines.size());
210 
211  ParseConfigLines(lines, &config_lines);
212 
213  // the next line will possibly remove some elements from "config_lines" so no
214  // node or component is doubly defined, always keeping the second repeat.
215  // Things being doubly defined can happen when a previously existing node or
216  // component is redefined in a new config file.
217  RemoveRedundantConfigLines(num_lines_initial, &config_lines);
218 
219  int32 initial_num_components = components_.size();
220  for (int32 pass = 0; pass <= 1; pass++) {
221  for (size_t i = 0; i < config_lines.size(); i++) {
222  const std::string &first_token = config_lines[i].FirstToken();
223  if (first_token == "component") {
224  if (pass == 0)
225  ProcessComponentConfigLine(initial_num_components,
226  &(config_lines[i]));
227  } else if (first_token == "component-node") {
228  ProcessComponentNodeConfigLine(pass, &(config_lines[i]));
229  } else if (first_token == "input-node") {
230  if (pass == 0)
231  ProcessInputNodeConfigLine(&(config_lines[i]));
232  } else if (first_token == "output-node") {
233  ProcessOutputNodeConfigLine(pass, &(config_lines[i]));
234  } else if (first_token == "dim-range-node") {
235  ProcessDimRangeNodeConfigLine(pass, &(config_lines[i]));
236  } else {
237  KALDI_ERR << "Invalid config-file line ('" << first_token
238  << "' not expected): " << config_lines[i].WholeLine();
239  }
240  }
241  }
242  Check();
243 }
244 
245 
246 // called only on pass 0 of ReadConfig.
248  int32 initial_num_components,
249  ConfigLine *config) {
250  std::string name, type;
251  if (!config->GetValue("name", &name))
252  KALDI_ERR << "Expected field name=<component-name> in config line: "
253  << config->WholeLine();
254  if (!IsToken(name)) // e.g. contains a space.
255  KALDI_ERR << "Component name '" << name << "' is not allowed, in line: "
256  << config->WholeLine();
257  if (!config->GetValue("type", &type))
258  KALDI_ERR << "Expected field type=<component-type> in config line: "
259  << config->WholeLine();
260  Component *new_component = Component::NewComponentOfType(type);
261  if (new_component == NULL)
262  KALDI_ERR << "Unknown component-type '" << type
263  << "' in config file. Check your code version and config.";
264  // the next call will call KALDI_ERR or KALDI_ASSERT and die if something
265  // went wrong.
266  new_component->InitFromConfig(config);
267  int32 index = GetComponentIndex(name);
268  if (index != -1) { // Replacing existing component.
269  if (index >= initial_num_components) {
270  // that index was something we added from this config.
271  KALDI_ERR << "You are adding two components with the same name: '"
272  << name << "'";
273  }
274  delete components_[index];
275  components_[index] = new_component;
276  } else {
277  components_.push_back(new_component);
278  component_names_.push_back(name);
279  }
280  if (config->HasUnusedValues())
281  KALDI_ERR << "Unused values '" << config->UnusedValues()
282  << "' in config line: " << config->WholeLine();
283 }
284 
285 
287  int32 pass,
288  ConfigLine *config) {
289 
290  std::string name;
291  if (!config->GetValue("name", &name))
292  KALDI_ERR << "Expected field name=<component-name> in config line: "
293  << config->WholeLine();
294 
295  std::string input_name = name + std::string("_input");
296  int32 input_node_index = GetNodeIndex(input_name),
297  node_index = GetNodeIndex(name);
298 
299  if (pass == 0) {
300  KALDI_ASSERT(input_node_index == -1 && node_index == -1);
301  // just set up the node types and names for now, we'll properly set them up
302  // on pass 1.
303  nodes_.push_back(NetworkNode(kDescriptor));
304  nodes_.push_back(NetworkNode(kComponent));
305  node_names_.push_back(input_name);
306  node_names_.push_back(name);
307  return;
308  } else {
309  KALDI_ASSERT(input_node_index != -1 && node_index == input_node_index + 1);
310  std::string component_name, input_descriptor;
311  if (!config->GetValue("component", &component_name))
312  KALDI_ERR << "Expected component=<component-name>, in config line: "
313  << config->WholeLine();
314  int32 component_index = GetComponentIndex(component_name);
315  if (component_index == -1)
316  KALDI_ERR << "No component named '" << component_name
317  << "', in config line: " << config->WholeLine();
318  nodes_[node_index].u.component_index = component_index;
319 
320  if (!config->GetValue("input", &input_descriptor))
321  KALDI_ERR << "Expected input=<input-descriptor>, in config line: "
322  << config->WholeLine();
323  std::vector<std::string> tokens;
324  if (!DescriptorTokenize(input_descriptor, &tokens))
325  KALDI_ERR << "Error tokenizing descriptor in config line "
326  << config->WholeLine();
327  std::vector<std::string> node_names_temp;
328  GetSomeNodeNames(&node_names_temp);
329  tokens.push_back("end of input");
330  const std::string *next_token = &(tokens[0]);
331  if (!nodes_[input_node_index].descriptor.Parse(node_names_temp,
332  &next_token))
333  KALDI_ERR << "Error parsing Descriptor in config line: "
334  << config->WholeLine();
335  if (config->HasUnusedValues())
336  KALDI_ERR << "Unused values '" << config->UnusedValues()
337  << " in config line: " << config->WholeLine();
338  }
339 }
340 
341 // called only on pass 0 of ReadConfig.
343  ConfigLine *config) {
344  std::string name;
345  if (!config->GetValue("name", &name))
346  KALDI_ERR << "Expected field name=<input-name> in config line: "
347  << config->WholeLine();
348  int32 dim;
349  if (!config->GetValue("dim", &dim))
350  KALDI_ERR << "Expected field dim=<input-dim> in config line: "
351  << config->WholeLine();
352 
353  if (config->HasUnusedValues())
354  KALDI_ERR << "Unused values '" << config->UnusedValues()
355  << " in config line: " << config->WholeLine();
356 
357  KALDI_ASSERT(GetNodeIndex(name) == -1);
358  if (dim <= 0)
359  KALDI_ERR << "Invalid dimension in config line: " << config->WholeLine();
360 
361  int32 node_index = nodes_.size();
362  nodes_.push_back(NetworkNode(kInput));
363  nodes_[node_index].dim = dim;
364  node_names_.push_back(name);
365 }
366 
367 
369  int32 pass,
370  ConfigLine *config) {
371  std::string name;
372  if (!config->GetValue("name", &name))
373  KALDI_ERR << "Expected field name=<input-name> in config line: "
374  << config->WholeLine();
375  int32 node_index = GetNodeIndex(name);
376  if (pass == 0) {
377  KALDI_ASSERT(node_index == -1);
378  nodes_.push_back(NetworkNode(kDescriptor));
379  node_names_.push_back(name);
380  } else {
381  KALDI_ASSERT(node_index != -1);
382  std::string input_descriptor;
383  if (!config->GetValue("input", &input_descriptor))
384  KALDI_ERR << "Expected input=<input-descriptor>, in config line: "
385  << config->WholeLine();
386  std::vector<std::string> tokens;
387  if (!DescriptorTokenize(input_descriptor, &tokens))
388  KALDI_ERR << "Error tokenizing descriptor in config line "
389  << config->WholeLine();
390  tokens.push_back("end of input");
391  // if the following fails it will die.
392  std::vector<std::string> node_names_temp;
393  GetSomeNodeNames(&node_names_temp);
394  const std::string *next_token = &(tokens[0]);
395  if (!nodes_[node_index].descriptor.Parse(node_names_temp, &next_token))
396  KALDI_ERR << "Error parsing descriptor (input=...) in config line "
397  << config->WholeLine();
398  std::string objective_type;
399  if (config->GetValue("objective", &objective_type)) {
400  if (objective_type == "linear") {
401  nodes_[node_index].u.objective_type = kLinear;
402  } else if (objective_type == "quadratic") {
403  nodes_[node_index].u.objective_type = kQuadratic;
404  } else {
405  KALDI_ERR << "Invalid objective type: " << objective_type;
406  }
407  } else {
408  // the default objective type is linear. This is what we use
409  // for softmax objectives; the LogSoftmaxLayer is included as the
410  // last layer, in this case.
411  nodes_[node_index].u.objective_type = kLinear;
412  }
413  if (config->HasUnusedValues())
414  KALDI_ERR << "Unused values '" << config->UnusedValues()
415  << " in config line: " << config->WholeLine();
416  }
417 }
418 
419 
421  int32 pass,
422  ConfigLine *config) {
423  std::string name;
424  if (!config->GetValue("name", &name))
425  KALDI_ERR << "Expected field name=<input-name> in config line: "
426  << config->WholeLine();
427  int32 node_index = GetNodeIndex(name);
428  if (pass == 0) {
429  KALDI_ASSERT(node_index == -1);
430  nodes_.push_back(NetworkNode(kDimRange));
431  node_names_.push_back(name);
432  } else {
433  KALDI_ASSERT(node_index != -1);
434  std::string input_node_name;
435  if (!config->GetValue("input-node", &input_node_name))
436  KALDI_ERR << "Expected input-node=<input-node-name>, in config line: "
437  << config->WholeLine();
439  if (!config->GetValue("dim", &dim))
440  KALDI_ERR << "Expected dim=<feature-dim>, in config line: "
441  << config->WholeLine();
442  if (!config->GetValue("dim-offset", &dim_offset))
443  KALDI_ERR << "Expected dim-offset=<dimension-offset>, in config line: "
444  << config->WholeLine();
445 
446  int32 input_node_index = GetNodeIndex(input_node_name);
447  if (input_node_index == -1 ||
448  !(nodes_[input_node_index].node_type == kComponent ||
449  nodes_[input_node_index].node_type == kInput))
450  KALDI_ERR << "invalid input-node " << input_node_name
451  << ": " << config->WholeLine();
452 
453  if (config->HasUnusedValues())
454  KALDI_ERR << "Unused values '" << config->UnusedValues()
455  << " in config line: " << config->WholeLine();
456 
457  NetworkNode &node = nodes_[node_index];
459  node.u.node_index = input_node_index;
460  node.dim = dim;
461  node.dim_offset = dim_offset;
462  }
463 }
464 
465 
466 int32 Nnet::GetNodeIndex(const std::string &node_name) const {
467  size_t size = node_names_.size();
468  for (size_t i = 0; i < size; i++)
469  if (node_names_[i] == node_name)
470  return static_cast<int32>(i);
471  return -1;
472 }
473 
474 int32 Nnet::GetComponentIndex(const std::string &component_name) const {
475  size_t size = component_names_.size();
476  for (size_t i = 0; i < size; i++)
477  if (component_names_[i] == component_name)
478  return static_cast<int32>(i);
479  return -1;
480 }
481 
482 
483 // note: the input to this function is a config generated from the nnet,
484 // containing the node info, concatenated with a config provided by the user.
485 //static
487  std::vector<ConfigLine> *config_lines) {
488  int32 num_lines = config_lines->size();
489  KALDI_ASSERT(num_lines_initial <= num_lines);
490  // node names and component names live in different namespaces.
491  unordered_map<std::string, int32, StringHasher> node_name_to_most_recent_line;
492  unordered_set<std::string, StringHasher> component_names;
493  typedef unordered_map<std::string, int32, StringHasher>::iterator IterType;
494 
495  std::vector<bool> to_remove(num_lines, false);
496  for (int32 line = 0; line < num_lines; line++) {
497  ConfigLine &config_line = (*config_lines)[line];
498  std::string name;
499  if (!config_line.GetValue("name", &name))
500  KALDI_ERR << "Config line has no field 'name=xxx': "
501  << config_line.WholeLine();
502  if (!IsValidName(name))
503  KALDI_ERR << "Name '" << name << "' is not allowable, in line: "
504  << config_line.WholeLine();
505  if (config_line.FirstToken() == "component") {
506  // a line starting with "component"... components live in their own
507  // namespace. No repeats are allowed because we never wrote them
508  // to the config generated from the nnet.
509  if (!component_names.insert(name).second) {
510  // we could not insert it because it was already there.
511  KALDI_ERR << "Component name " << name
512  << " appears twice in the same config file.";
513  }
514  } else {
515  // the line defines some sort of network node, e.g. component-node.
516  IterType iter = node_name_to_most_recent_line.find(name);
517  if (iter != node_name_to_most_recent_line.end()) {
518  // name is repeated.
519  int32 prev_line = iter->second;
520  if (prev_line >= num_lines_initial) {
521  // user-provided config contained repeat of node with this name.
522  KALDI_ERR << "Node name " << name
523  << " appears twice in the same config file.";
524  }
525  // following assert checks that the config-file generated
526  // from an actual nnet does not contain repeats.. that
527  // would be a bug so check it with assert.
528  KALDI_ASSERT(line >= num_lines_initial);
529  to_remove[prev_line] = true;
530  }
531  node_name_to_most_recent_line[name] = line;
532  }
533  }
534  // Now remove any lines with to_remove[i] = true.
535  std::vector<ConfigLine> config_lines_out;
536  config_lines_out.reserve(num_lines);
537  for (int32 i = 0; i < num_lines; i++) {
538  if (!to_remove[i])
539  config_lines_out.push_back((*config_lines)[i]);
540  }
541  config_lines->swap(config_lines_out);
542 }
543 
544 // copy constructor.
546  node_type(other.node_type),
547  descriptor(other.descriptor),
548  dim(other.dim),
549  dim_offset(other.dim_offset) {
550  u.component_index = other.u.component_index;
551 }
552 
553 
555  for (size_t i = 0; i < components_.size(); i++)
556  delete components_[i];
557  component_names_.clear();
558  components_.clear();
559  node_names_.clear();
560  nodes_.clear();
561 }
562 
564  std::vector<std::string> *modified_node_names) const {
565  modified_node_names->resize(node_names_.size());
566  const std::string invalid_name = "**";
567  size_t size = node_names_.size();
568  for (size_t i = 0; i < size; i++) {
569  if (nodes_[i].node_type == kComponent ||
570  nodes_[i].node_type == kInput ||
571  nodes_[i].node_type == kDimRange) {
572  (*modified_node_names)[i] = node_names_[i];
573  } else {
574  (*modified_node_names)[i] = invalid_name;
575  }
576  }
577 }
578 
579 void Nnet::Swap(Nnet *other) {
580  component_names_.swap(other->component_names_);
581  components_.swap(other->components_);
582  node_names_.swap(other->node_names_);
583  nodes_.swap(other->nodes_);
584 }
585 
586 void Nnet::Read(std::istream &is, bool binary) {
587  Destroy();
588  int first_char = PeekToken(is, binary);
589  if (first_char == 'T') {
590  // This branch is to allow '.mdl' files (containing a TransitionModel
591  // and then an AmNnetSimple) to be read where .raw files (containing
592  // just an Nnet) would be expected. This is often convenient.
593  TransitionModel temp_trans_model;
594  temp_trans_model.Read(is, binary);
595  AmNnetSimple temp_am_nnet;
596  temp_am_nnet.Read(is, binary);
597  temp_am_nnet.GetNnet().Swap(this);
598  return;
599  }
600 
601  ExpectToken(is, binary, "<Nnet3>");
602  std::ostringstream config_file_out;
603  std::string cur_line;
604  getline(is, cur_line); // Eat up a single newline.
605  if (!(cur_line == "" || cur_line == "\r"))
606  KALDI_ERR << "Expected newline in config file, got " << cur_line;
607  while (getline(is, cur_line)) {
608  // config-file part of file is terminated by an empty line.
609  if (cur_line == "" || cur_line == "\r")
610  break;
611  config_file_out << cur_line << std::endl;
612  }
613  // Now we read the Components; later we try to parse the config_lines.
614  ExpectToken(is, binary, "<NumComponents>");
615  int32 num_components;
616  ReadBasicType(is, binary, &num_components);
617  KALDI_ASSERT(num_components >= 0 && num_components < 100000);
618  components_.resize(num_components, NULL);
619  component_names_.resize(num_components);
620  for (int32 c = 0; c < num_components; c++) {
621  ExpectToken(is, binary, "<ComponentName>");
622  ReadToken(is, binary, &(component_names_[c]));
623  components_[c] = Component::ReadNew(is, binary);
624  }
625  ExpectToken(is, binary, "</Nnet3>");
626  std::istringstream config_file_in(config_file_out.str());
627  this->ReadConfig(config_file_in);
628 }
629 
630 void Nnet::Write(std::ostream &os, bool binary) const {
631  WriteToken(os, binary, "<Nnet3>");
632  os << std::endl;
633  std::vector<std::string> config_lines;
634  const bool include_dim = false;
635  GetConfigLines(include_dim, &config_lines);
636  for (size_t i = 0; i < config_lines.size(); i++) {
637  KALDI_ASSERT(!config_lines[i].empty());
638  os << config_lines[i] << std::endl;
639  }
640  // A blank line terminates the config-like section of the file.
641  os << std::endl;
642  // Now write the Components
643  int32 num_components = components_.size();
644  WriteToken(os, binary, "<NumComponents>");
645  WriteBasicType(os, binary, num_components);
646  if (!binary)
647  os << std::endl;
648  for (int32 c = 0; c < num_components; c++) {
649  WriteToken(os, binary, "<ComponentName>");
650  WriteToken(os, binary, component_names_[c]);
651  components_[c]->Write(os, binary);
652  if (!binary)
653  os << std::endl;
654  }
655  WriteToken(os, binary, "</Nnet3>");
656 }
657 
659  int32 ans = 1;
660  for (int32 n = 0; n < NumNodes(); n++) {
661  const NetworkNode &node = nodes_[n];
662  if (node.node_type == kDescriptor)
663  ans = Lcm(ans, node.descriptor.Modulus());
664  }
665  return ans;
666 }
667 
668 
669 int32 Nnet::InputDim(const std::string &input_name) const {
670  int32 n = GetNodeIndex(input_name);
671  if (n == -1) return -1;
672  const NetworkNode &node = nodes_[n];
673  if (node.node_type != kInput) return -1;
674  return node.dim;
675 }
676 
677 int32 Nnet::OutputDim(const std::string &input_name) const {
678  int32 n = GetNodeIndex(input_name);
679  if (n == -1 || !IsOutputNode(n)) return -1;
680  const NetworkNode &node = nodes_[n];
681  return node.Dim(*this);
682 }
683 
684 const std::string& Nnet::GetNodeName(int32 node_index) const {
685  KALDI_ASSERT(static_cast<size_t>(node_index) < node_names_.size());
686  return node_names_[node_index];
687 }
688 
689 const std::string& Nnet::GetComponentName(int32 component_index) const {
690  KALDI_ASSERT(static_cast<size_t>(component_index) < component_names_.size());
691  return component_names_[component_index];
692 }
693 
694 void Nnet::Check(bool warn_for_orphans) const {
695  int32 num_nodes = nodes_.size(),
696  num_input_nodes = 0,
697  num_output_nodes = 0;
698  KALDI_ASSERT(num_nodes != 0);
699  for (int32 n = 0; n < num_nodes; n++) {
700  const NetworkNode &node = nodes_[n];
701  std::string node_name = node_names_[n];
702  KALDI_ASSERT(GetNodeIndex(node_name) == n);
703  switch (node.node_type) {
704  case kInput:
705  KALDI_ASSERT(node.dim > 0);
706  num_input_nodes++;
707  break;
708  case kDescriptor: {
709  if (IsOutputNode(n))
710  num_output_nodes++;
711  std::vector<int32> node_deps;
712  node.descriptor.GetNodeDependencies(&node_deps);
713  SortAndUniq(&node_deps);
714  for (size_t i = 0; i < node_deps.size(); i++) {
715  int32 src_node = node_deps[i];
716  KALDI_ASSERT(src_node >= 0 && src_node < num_nodes);
717  NodeType src_type = nodes_[src_node].node_type;
718  if (src_type != kInput && src_type != kDimRange &&
719  src_type != kComponent)
720  KALDI_ERR << "Invalid source node type in Descriptor: source node "
721  << node_names_[src_node];
722  }
723  break;
724  }
725  case kComponent: {
726  KALDI_ASSERT(n > 0 && nodes_[n-1].node_type == kDescriptor);
727  const NetworkNode &src_node = nodes_[n-1];
728  const Component *c = GetComponent(node.u.component_index);
729  int32 src_dim, input_dim = c->InputDim();
730  try {
731  src_dim = src_node.Dim(*this);
732  } catch (...) {
733  KALDI_ERR << "Error in Descriptor for network-node "
734  << node_name << " (see error above)";
735  }
736  if (src_dim != input_dim) {
737  KALDI_ERR << "Dimension mismatch for network-node "
738  << node_name << ": input-dim "
739  << src_dim << " versus component-input-dim "
740  << input_dim;
741  }
742  break;
743  }
744  case kDimRange: {
745  int32 input_node = node.u.node_index;
746  KALDI_ASSERT(input_node >= 0 && input_node < num_nodes);
747  NodeType input_type = nodes_[input_node].node_type;
748  if (input_type != kInput && input_type != kComponent)
749  KALDI_ERR << "Invalid source node type in DimRange node: source node "
750  << node_names_[input_node];
751  int32 input_dim = nodes_[input_node].Dim(*this);
752  if (!(node.dim > 0 && node.dim_offset >= 0 &&
753  node.dim + node.dim_offset <= input_dim)) {
754  KALDI_ERR << "Invalid node dimensions for DimRange node: " << node_name
755  << ": input-dim=" << input_dim << ", dim=" << node.dim
756  << ", dim-offset=" << node.dim_offset;
757  }
758  break;
759  }
760  default:
761  KALDI_ERR << "Invalid node type for node " << node_name;
762  }
763  }
764 
765  int32 num_components = components_.size();
766  for (int32 c = 0; c < num_components; c++) {
767  const std::string &component_name = component_names_[c];
768  KALDI_ASSERT(GetComponentIndex(component_name) == c &&
769  "Duplicate component names?");
770  }
771  KALDI_ASSERT(num_input_nodes > 0);
772  KALDI_ASSERT(num_output_nodes > 0);
773 
774 
775  if (warn_for_orphans) {
776  std::vector<int32> orphans;
777  FindOrphanComponents(*this, &orphans);
778  for (size_t i = 0; i < orphans.size(); i++) {
779  KALDI_WARN << "Component " << GetComponentName(orphans[i])
780  << " is never used by any node.";
781  }
782  FindOrphanNodes(*this, &orphans);
783  for (size_t i = 0; i < orphans.size(); i++) {
784  if (!IsComponentInputNode(orphans[i])) {
785  // There is no point warning about component-input nodes, since the
786  // warning will be printed for the corresponding component nodes.. a
787  // duplicate warning might be confusing to the user, as the
788  // component-input nodes are implicit and usually hidden from users.
789  KALDI_WARN << "Node " << GetNodeName(orphans[i])
790  << " is never used to compute any output.";
791  }
792  }
793  }
794 }
795 
796 // copy constructor
797 Nnet::Nnet(const Nnet &nnet):
798  component_names_(nnet.component_names_),
799  components_(nnet.components_.size()),
800  node_names_(nnet.node_names_),
801  nodes_(nnet.nodes_) {
802  for (size_t i = 0; i < components_.size(); i++)
803  components_[i] = nnet.components_[i]->Copy();
804  Check();
805 }
806 
807 Nnet& Nnet::operator =(const Nnet &nnet) {
808  if (this == &nnet)
809  return *this;
810  Destroy();
812  components_.resize(nnet.components_.size());
813  node_names_ = nnet.node_names_;
814  nodes_ = nnet.nodes_;
815  for (size_t i = 0; i < components_.size(); i++)
816  components_[i] = nnet.components_[i]->Copy();
817  Check();
818  return *this;
819 }
820 
821 std::string Nnet::Info() const {
822  std::ostringstream os;
823 
824  if(IsSimpleNnet(*this)) {
825  int32 left_context, right_context;
826  ComputeSimpleNnetContext(*this, &left_context, &right_context);
827  os << "left-context: " << left_context << "\n";
828  os << "right-context: " << right_context << "\n";
829  }
830  os << "num-parameters: " << NumParameters(*this) << "\n";
831  os << "modulus: " << this->Modulus() << "\n";
832  std::vector<std::string> config_lines;
833  bool include_dim = true;
834  GetConfigLines(include_dim, &config_lines);
835  for (size_t i = 0; i < config_lines.size(); i++)
836  os << config_lines[i] << "\n";
837  // Get component info.
838  for (size_t i = 0; i < components_.size(); i++)
839  os << "component name=" << component_names_[i]
840  << " type=" << components_[i]->Info() << "\n";
841  return os.str();
842 }
843 
845  std::vector<int32> orphan_components;
846  FindOrphanComponents(*this, &orphan_components);
847  KALDI_LOG << "Removing " << orphan_components.size()
848  << " orphan components.";
849  if (orphan_components.empty())
850  return;
851  int32 old_num_components = components_.size(),
852  new_num_components = 0;
853  std::vector<int32> old2new_map(old_num_components, 0);
854  for (size_t i = 0; i < orphan_components.size(); i++)
855  old2new_map[orphan_components[i]] = -1;
856  std::vector<Component*> new_components;
857  std::vector<std::string> new_component_names;
858  for (int32 c = 0; c < old_num_components; c++) {
859  if (old2new_map[c] != -1) {
860  old2new_map[c] = new_num_components++;
861  new_components.push_back(components_[c]);
862  new_component_names.push_back(component_names_[c]);
863  } else {
864  delete components_[c];
865  components_[c] = NULL;
866  }
867  }
868  for (int32 n = 0; n < NumNodes(); n++) {
869  if (IsComponentNode(n)) {
870  int32 old_c = nodes_[n].u.component_index,
871  new_c = old2new_map[old_c];
872  KALDI_ASSERT(new_c >= 0);
873  nodes_[n].u.component_index = new_c;
874  }
875  }
876  components_ = new_components;
877  component_names_ = new_component_names;
878  Check();
879 }
880 
881 void Nnet::RemoveSomeNodes(const std::vector<int32> &nodes_to_remove) {
882  if (nodes_to_remove.empty())
883  return;
884  int32 old_num_nodes = nodes_.size(),
885  new_num_nodes = 0;
886  std::vector<int32> old2new_map(old_num_nodes, 0);
887  for (size_t i = 0; i < nodes_to_remove.size(); i++)
888  old2new_map[nodes_to_remove[i]] = -1;
889  std::vector<NetworkNode> new_nodes;
890  std::vector<std::string> new_node_names;
891  for (int32 n = 0; n < old_num_nodes; n++) {
892  if (old2new_map[n] != -1) {
893  old2new_map[n] = new_num_nodes++;
894  new_nodes.push_back(nodes_[n]);
895  new_node_names.push_back(node_names_[n]);
896  }
897  }
898  for (int32 n = 0; n < new_num_nodes; n++) {
899  if (new_nodes[n].node_type == kDescriptor) {
900  // we need to renumber the node indexes inside the descriptor. It's
901  // easiest to do this by converting back and forth to text format. This
902  // is inefficient, of course, but these graphs are typically quite small.
903  std::ostringstream os;
904  new_nodes[n].descriptor.WriteConfig(os, node_names_);
905  std::vector<std::string> tokens;
906  DescriptorTokenize(os.str(), &tokens);
907  KALDI_ASSERT(!tokens.empty());
908  tokens.push_back("end of input");
909  const std::string *token = &(tokens[0]);
910  Descriptor new_descriptor;
911  // this should work; if it doesn't, there was a programming error.
912  if (!new_nodes[n].descriptor.Parse(new_node_names, &token)) {
913  KALDI_ERR << "Code error removing orphan nodes.";
914  }
915  } else if (new_nodes[n].node_type == kDimRange) {
916  int32 old_node_index = new_nodes[n].u.node_index,
917  new_node_index = old2new_map[old_node_index];
918  KALDI_ASSERT(new_node_index >= 0 && new_node_index <= new_num_nodes);
919  new_nodes[n].u.node_index = new_node_index;
920  }
921  }
922  nodes_ = new_nodes;
923  node_names_ = new_node_names;
924  bool warn_for_orphans = false;
925  // don't warn about orphans, because at this stage we may have
926  // orphan components that will later be removed by calling
927  // RemoveOrphanComponents().
928  Check(warn_for_orphans);
929 }
930 
931 
932 void Nnet::RemoveOrphanNodes(bool remove_orphan_inputs) {
933  std::vector<int32> orphan_nodes;
934  FindOrphanNodes(*this, &orphan_nodes);
935  if (!remove_orphan_inputs)
936  for (int32 i = 0; i < orphan_nodes.size(); i++)
937  if (IsInputNode(orphan_nodes[i]))
938  orphan_nodes.erase(orphan_nodes.begin() + i);
939  // For each component-node, its component-input node (which is kind of a
940  // "hidden" node) would be included in 'orphan_nodes', but for diagnostic
941  // purposes we want to exclude these from 'num_nodes_removed' to avoid
942  // confusing users.
943  int32 num_nodes_removed = 0;
944  for (int32 i = 0; i < orphan_nodes.size(); i++)
945  if (!IsComponentInputNode(orphan_nodes[i]))
946  num_nodes_removed++;
947  RemoveSomeNodes(orphan_nodes);
948  KALDI_LOG << "Removed " << num_nodes_removed << " orphan nodes.";
949 }
950 
952  // resets random-number generators for all random
953  // components.
954  for (int32 c = 0; c < NumComponents(); c++) {
955  RandomComponent *rc = dynamic_cast<RandomComponent*>(GetComponent(c));
956  if (rc != NULL)
957  rc->ResetGenerator();
958  }
959 }
960 
961 } // namespace nnet3
962 } // namespace kaldi
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
int32 InputDim(const std::string &input_name) const
Definition: nnet-nnet.cc:669
int32 NumNodes() const
Definition: nnet-nnet.h:126
const std::string & FirstToken() const
Definition: text-utils.h:228
void GetNodeDependencies(std::vector< int32 > *node_indexes) const
const std::string WholeLine()
Definition: text-utils.h:230
int32 AddComponent(const std::string &name, Component *component)
Adds a new component with the given name, which should not be the same as any existing component name...
Definition: nnet-nnet.cc:161
void ResetGenerators()
Definition: nnet-nnet.cc:951
void Write(std::ostream &ostream, bool binary) const
Definition: nnet-nnet.cc:630
void ReadConfig(std::istream &config_file)
Definition: nnet-nnet.cc:189
void FindOrphanComponents(const Nnet &nnet, std::vector< int32 > *components)
This function finds a list of components that are never used, and outputs the integer component index...
Definition: nnet-utils.cc:591
void GetSomeNodeNames(std::vector< std::string > *modified_node_names) const
Definition: nnet-nnet.cc:563
const std::string & GetNodeName(int32 node_index) const
returns individual node name.
Definition: nnet-nnet.cc:684
bool Parse(const std::vector< std::string > &node_names, const std::string **next_token)
Abstract base-class for neural-net components.
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
void FindOrphanNodes(const Nnet &nnet, std::vector< int32 > *nodes)
This function finds a list of nodes that are never used to compute any output, and outputs the intege...
Definition: nnet-utils.cc:607
bool DescriptorTokenize(const std::string &input, std::vector< std::string > *tokens)
This function tokenizes input when parsing Descriptor configuration values.
Definition: nnet-parse.cc:30
bool IsInputNode(int32 node) const
Returns true if this is an input node, meaning that it is of type kInput.
Definition: nnet-nnet.cc:120
static void RemoveRedundantConfigLines(int32 num_lines_initial, std::vector< ConfigLine > *config_lines)
Definition: nnet-nnet.cc:486
kaldi::int32 int32
void ReadToken(std::istream &is, bool binary, std::string *str)
ReadToken gets the next token and puts it in str (exception on failure).
Definition: io-funcs.cc:154
std::vector< std::string > component_names_
Definition: nnet-nnet.h:326
std::vector< Component * > components_
Definition: nnet-nnet.h:331
void SetComponent(int32 c, Component *component)
Replace the component indexed by c with a new component.
Definition: nnet-nnet.cc:155
virtual int32 OutputDim() const =0
Returns output-dimension of this component.
bool IsComponentNode(int32 node) const
Returns true if this is a component node, meaning that it is of type kComponent.
Definition: nnet-nnet.cc:132
void SortAndUniq(std::vector< T > *vec)
Sorts a vector and uniq's it (removes duplicates).
Definition: stl-utils.h:39
void RemoveOrphanComponents()
Definition: nnet-nnet.cc:844
ObjectiveType objective_type
Definition: nnet-nnet.h:97
bool IsValidName(const std::string &name)
Returns true if 'name' would be a valid name for a component or node in an nnet3 Nnet.
Definition: text-utils.cc:553
const Nnet & GetNnet() const
std::string GetAsConfigLine(int32 node_index, bool include_dim) const
Definition: nnet-nnet.cc:71
void Read(std::istream &is, bool binary)
void SetNodeName(int32 node_index, const std::string &new_name)
This can be used to modify individual node names.
Definition: nnet-nnet.cc:53
std::vector< std::string > node_names_
Definition: nnet-nnet.h:337
int32 OutputDim(const std::string &output_name) const
Definition: nnet-nnet.cc:677
I Lcm(I m, I n)
Returns the least common multiple of two integers.
Definition: kaldi-math.h:318
This file contains some miscellaneous functions dealing with class Nnet.
bool IsToken(const std::string &token)
Returns true if "token" is nonempty, and all characters are printable and whitespace-free.
Definition: text-utils.cc:105
This file contains declarations of components that are "simple", meaning they don't care about the in...
std::string Info() const
returns some human-readable information about the network, mostly for debugging purposes.
Definition: nnet-nnet.cc:821
int32 Modulus() const
[Relevant for clockwork RNNs and similar].
Definition: nnet-nnet.cc:658
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340
void ProcessOutputNodeConfigLine(int32 pass, ConfigLine *config)
Definition: nnet-nnet.cc:368
std::string UnusedValues() const
returns e.g.
Definition: text-utils.cc:518
int32 NumParameters(const Nnet &src)
Returns the total of the number of parameters in the updatable components of the nnet.
Definition: nnet-utils.cc:359
void ParseConfigLines(const std::vector< std::string > &lines, std::vector< ConfigLine > *config_lines)
Definition: nnet-parse.cc:224
bool IsOutputNode(int32 node) const
Returns true if this is an output node, meaning that it is of type kDescriptor and is not directly fo...
Definition: nnet-nnet.cc:112
int32 Dim(const Nnet &nnet) const
Definition: nnet-nnet.cc:33
static void ExpectToken(const std::string &token, const std::string &what_we_are_parsing, const std::string **next_token)
void ComputeSimpleNnetContext(const Nnet &nnet, int32 *left_context, int32 *right_context)
ComputeSimpleNnetContext computes the left-context and right-context of a nnet.
Definition: nnet-utils.cc:146
void Read(std::istream &istream, bool binary)
Definition: nnet-nnet.cc:586
void Read(std::istream &is, bool binary)
struct rnnlm::@11::@12 n
void RemoveOrphanNodes(bool remove_orphan_inputs=false)
Definition: nnet-nnet.cc:932
int32 GetComponentIndex(const std::string &node_name) const
returns index associated with this component name, or -1 if no such index.
Definition: nnet-nnet.cc:474
void GetConfigLines(bool include_dim, std::vector< std::string > *config_lines) const
Definition: nnet-nnet.cc:180
void ProcessComponentConfigLine(int32 initial_num_components, ConfigLine *config)
Definition: nnet-nnet.cc:247
#define KALDI_ERR
Definition: kaldi-error.h:147
static Component * ReadNew(std::istream &is, bool binary)
Read component from stream (works out its type). Dies on error.
#define KALDI_WARN
Definition: kaldi-error.h:150
const std::string & GetComponentName(int32 component_index) const
returns individual component name.
Definition: nnet-nnet.cc:689
void ReadConfigLines(std::istream &is, std::vector< std::string > *lines)
This function reads in a config file and *appends* its contents to a vector of lines; it is responsib...
Definition: text-utils.cc:564
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
const std::vector< std::string > & GetComponentNames() const
returns vector of component names (needed by some parsing code, for instance).
Definition: nnet-nnet.cc:67
int PeekToken(std::istream &is, bool binary)
PeekToken will return the first character of the next token, or -1 if end of file.
Definition: io-funcs.cc:170
NetworkNode is used to represent three types of thing: either an input of the network (which pretty ...
Definition: nnet-nnet.h:81
Component * GetComponent(int32 c)
Return component indexed c. Not a copy; not owned by caller.
Definition: nnet-nnet.cc:150
void RemoveSomeNodes(const std::vector< int32 > &nodes_to_remove)
Definition: nnet-nnet.cc:881
void ProcessDimRangeNodeConfigLine(int32 pass, ConfigLine *config)
Definition: nnet-nnet.cc:420
bool IsSimpleNnet(const Nnet &nnet)
This function returns true if the nnet has the following properties: It has an output called "output"...
Definition: nnet-utils.cc:52
int32 NumComponents() const
Definition: nnet-nnet.h:124
This class is responsible for parsing input like hi-there xx=yyy a=b c empty= f-oo=Append(bar, sss) ba_z=123 bing='a b c' baz="a b c d='a b' e" and giving you access to the fields, in this case.
Definition: text-utils.h:205
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
bool HasUnusedValues() const
Definition: text-utils.cc:510
bool GetValue(const std::string &key, std::string *value)
Definition: text-utils.cc:427
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34
void Check(bool warn_for_orphans=true) const
Checks the neural network for validity (dimension matches and various other requirements).
Definition: nnet-nnet.cc:694
void ProcessInputNodeConfigLine(ConfigLine *config)
Definition: nnet-nnet.cc:342
union kaldi::nnet3::NetworkNode::@15 u
int32 GetNodeIndex(const std::string &node_name) const
returns index associated with this node name, or -1 if no such index.
Definition: nnet-nnet.cc:466
virtual int32 InputDim() const =0
Returns input-dimension of this component.
static Component * NewComponentOfType(const std::string &type)
Returns a new Component of the given type e.g.
#define KALDI_LOG
Definition: kaldi-error.h:153
bool IsDescriptorNode(int32 node) const
Returns true if this is a descriptor node, meaning that it is of type kDescriptor.
Definition: nnet-nnet.cc:126
Nnet & operator=(const Nnet &nnet)
Definition: nnet-nnet.cc:807
bool IsDimRangeNode(int32 node) const
Returns true if this is a dim-range node, meaning that it is of type kDimRange.
Definition: nnet-nnet.cc:138
NetworkNode(NodeType nt=kNone)
Definition: nnet-nnet.h:107
void ProcessComponentNodeConfigLine(int32 pass, ConfigLine *config)
Definition: nnet-nnet.cc:286
const std::vector< std::string > & GetNodeNames() const
returns vector of node names (needed by some parsing code, for instance).
Definition: nnet-nnet.cc:63
int32 Dim(const Nnet &nnet) const
virtual void InitFromConfig(ConfigLine *cfl)=0
Initialize, from a ConfigLine object.
void Swap(Nnet *other)
Definition: nnet-nnet.cc:579
bool IsComponentInputNode(int32 node) const
Returns true if this is a component-input node, i.e.
Definition: nnet-nnet.cc:172