54 if (!(static_cast<size_t>(node_index) < nodes_.size()))
56 if (GetNodeIndex(new_name) != -1)
57 KALDI_ERR <<
"You cannot rename a node to create a duplicate node name";
59 KALDI_ERR <<
"Node name " << new_name <<
" is not allowed.";
68 return component_names_;
72 std::ostringstream ans;
74 nodes_.size() == node_names_.size());
76 const std::string &name = node_names_[
node_index];
77 switch (node.node_type) {
79 ans <<
"input-node name=" << name <<
" dim=" << node.dim;
85 ans <<
"output-node name=" << name <<
" input=";
86 node.descriptor.WriteConfig(ans, node_names_);
88 ans <<
" dim=" << node.Dim(*
this);
89 ans <<
" objective=" << (node.u.objective_type ==
kLinear ?
"linear" :
93 ans <<
"component-node name=" << name <<
" component=" 94 << component_names_[node.u.component_index] <<
" input=";
96 nodes_[node_index-1].descriptor.WriteConfig(ans, node_names_);
98 ans <<
" input-dim=" << nodes_[node_index-1].Dim(*
this)
99 <<
" output-dim=" << node.Dim(*
this);
102 ans <<
"dim-range-node name=" << name <<
" input-node=" 103 << node_names_[node.u.node_index] <<
" dim-offset=" 104 << node.dim_offset <<
" dim=" << node.dim;
113 int32 size = nodes_.size();
121 int32 size = nodes_.size();
127 int32 size = nodes_.size();
133 int32 size = nodes_.size();
139 int32 size = nodes_.size();
146 KALDI_ASSERT(static_cast<size_t>(c) < components_.size());
147 return components_[c];
151 KALDI_ASSERT(static_cast<size_t>(c) < components_.size());
152 return components_[c];
156 KALDI_ASSERT(static_cast<size_t>(c) < components_.size());
157 delete components_[c];
158 components_[c] = component;
163 int32 ans = components_.size();
165 components_.push_back(component);
166 component_names_.push_back(name);
173 int32 size = nodes_.size();
175 return (node + 1 < size &&
181 std::vector<std::string> *config_lines)
const {
182 config_lines->clear();
183 for (
int32 n = 0;
n < NumNodes();
n++)
184 if (!IsComponentInputNode(
n))
185 config_lines->push_back(GetAsConfigLine(
n, include_dim));
191 std::vector<std::string> lines;
196 const bool include_dim =
false;
197 GetConfigLines(include_dim, &lines);
204 int32 num_lines_initial = lines.size();
209 std::vector<ConfigLine> config_lines(lines.size());
217 RemoveRedundantConfigLines(num_lines_initial, &config_lines);
219 int32 initial_num_components = components_.size();
220 for (
int32 pass = 0; pass <= 1; pass++) {
221 for (
size_t i = 0;
i < config_lines.size();
i++) {
222 const std::string &first_token = config_lines[
i].FirstToken();
223 if (first_token ==
"component") {
225 ProcessComponentConfigLine(initial_num_components,
227 }
else if (first_token ==
"component-node") {
228 ProcessComponentNodeConfigLine(pass, &(config_lines[
i]));
229 }
else if (first_token ==
"input-node") {
231 ProcessInputNodeConfigLine(&(config_lines[
i]));
232 }
else if (first_token ==
"output-node") {
233 ProcessOutputNodeConfigLine(pass, &(config_lines[
i]));
234 }
else if (first_token ==
"dim-range-node") {
235 ProcessDimRangeNodeConfigLine(pass, &(config_lines[
i]));
237 KALDI_ERR <<
"Invalid config-file line ('" << first_token
238 <<
"' not expected): " << config_lines[
i].WholeLine();
248 int32 initial_num_components,
250 std::string name, type;
251 if (!config->
GetValue(
"name", &name))
252 KALDI_ERR <<
"Expected field name=<component-name> in config line: " 255 KALDI_ERR <<
"Component name '" << name <<
"' is not allowed, in line: " 257 if (!config->
GetValue(
"type", &type))
258 KALDI_ERR <<
"Expected field type=<component-type> in config line: " 261 if (new_component == NULL)
262 KALDI_ERR <<
"Unknown component-type '" << type
263 <<
"' in config file. Check your code version and config.";
267 int32 index = GetComponentIndex(name);
269 if (index >= initial_num_components) {
271 KALDI_ERR <<
"You are adding two components with the same name: '" 274 delete components_[index];
275 components_[index] = new_component;
277 components_.push_back(new_component);
278 component_names_.push_back(name);
282 <<
"' in config line: " << config->
WholeLine();
291 if (!config->
GetValue(
"name", &name))
292 KALDI_ERR <<
"Expected field name=<component-name> in config line: " 295 std::string input_name = name + std::string(
"_input");
296 int32 input_node_index = GetNodeIndex(input_name),
305 node_names_.push_back(input_name);
306 node_names_.push_back(name);
310 std::string component_name, input_descriptor;
311 if (!config->
GetValue(
"component", &component_name))
312 KALDI_ERR <<
"Expected component=<component-name>, in config line: " 315 if (component_index == -1)
316 KALDI_ERR <<
"No component named '" << component_name
317 <<
"', in config line: " << config->
WholeLine();
320 if (!config->
GetValue(
"input", &input_descriptor))
321 KALDI_ERR <<
"Expected input=<input-descriptor>, in config line: " 323 std::vector<std::string> tokens;
325 KALDI_ERR <<
"Error tokenizing descriptor in config line " 327 std::vector<std::string> node_names_temp;
328 GetSomeNodeNames(&node_names_temp);
329 tokens.push_back(
"end of input");
330 const std::string *next_token = &(tokens[0]);
333 KALDI_ERR <<
"Error parsing Descriptor in config line: " 337 <<
" in config line: " << config->
WholeLine();
345 if (!config->
GetValue(
"name", &name))
346 KALDI_ERR <<
"Expected field name=<input-name> in config line: " 350 KALDI_ERR <<
"Expected field dim=<input-dim> in config line: " 355 <<
" in config line: " << config->
WholeLine();
364 node_names_.push_back(name);
372 if (!config->
GetValue(
"name", &name))
373 KALDI_ERR <<
"Expected field name=<input-name> in config line: " 379 node_names_.push_back(name);
382 std::string input_descriptor;
383 if (!config->
GetValue(
"input", &input_descriptor))
384 KALDI_ERR <<
"Expected input=<input-descriptor>, in config line: " 386 std::vector<std::string> tokens;
388 KALDI_ERR <<
"Error tokenizing descriptor in config line " 390 tokens.push_back(
"end of input");
392 std::vector<std::string> node_names_temp;
393 GetSomeNodeNames(&node_names_temp);
394 const std::string *next_token = &(tokens[0]);
395 if (!nodes_[node_index].
descriptor.
Parse(node_names_temp, &next_token))
396 KALDI_ERR <<
"Error parsing descriptor (input=...) in config line " 399 if (config->
GetValue(
"objective", &objective_type)) {
400 if (objective_type ==
"linear") {
402 }
else if (objective_type ==
"quadratic") {
415 <<
" in config line: " << config->
WholeLine();
424 if (!config->
GetValue(
"name", &name))
425 KALDI_ERR <<
"Expected field name=<input-name> in config line: " 431 node_names_.push_back(name);
434 std::string input_node_name;
435 if (!config->
GetValue(
"input-node", &input_node_name))
436 KALDI_ERR <<
"Expected input-node=<input-node-name>, in config line: " 440 KALDI_ERR <<
"Expected dim=<feature-dim>, in config line: " 442 if (!config->
GetValue(
"dim-offset", &dim_offset))
443 KALDI_ERR <<
"Expected dim-offset=<dimension-offset>, in config line: " 446 int32 input_node_index = GetNodeIndex(input_node_name);
447 if (input_node_index == -1 ||
450 KALDI_ERR <<
"invalid input-node " << input_node_name
455 <<
" in config line: " << config->
WholeLine();
467 size_t size = node_names_.size();
468 for (
size_t i = 0;
i < size;
i++)
469 if (node_names_[
i] == node_name)
470 return static_cast<int32>(
i);
475 size_t size = component_names_.size();
476 for (
size_t i = 0;
i < size;
i++)
477 if (component_names_[
i] == component_name)
478 return static_cast<int32>(
i);
487 std::vector<ConfigLine> *config_lines) {
488 int32 num_lines = config_lines->size();
491 unordered_map<std::string, int32, StringHasher> node_name_to_most_recent_line;
492 unordered_set<std::string, StringHasher> component_names;
493 typedef unordered_map<std::string, int32, StringHasher>::iterator IterType;
495 std::vector<bool> to_remove(num_lines,
false);
496 for (
int32 line = 0; line < num_lines; line++) {
497 ConfigLine &config_line = (*config_lines)[line];
499 if (!config_line.
GetValue(
"name", &name))
500 KALDI_ERR <<
"Config line has no field 'name=xxx': " 503 KALDI_ERR <<
"Name '" << name <<
"' is not allowable, in line: " 505 if (config_line.
FirstToken() ==
"component") {
509 if (!component_names.insert(name).second) {
512 <<
" appears twice in the same config file.";
516 IterType iter = node_name_to_most_recent_line.find(name);
517 if (iter != node_name_to_most_recent_line.end()) {
519 int32 prev_line = iter->second;
520 if (prev_line >= num_lines_initial) {
523 <<
" appears twice in the same config file.";
529 to_remove[prev_line] =
true;
531 node_name_to_most_recent_line[name] = line;
535 std::vector<ConfigLine> config_lines_out;
536 config_lines_out.reserve(num_lines);
537 for (
int32 i = 0;
i < num_lines;
i++) {
539 config_lines_out.push_back((*config_lines)[i]);
541 config_lines->swap(config_lines_out);
555 for (
size_t i = 0;
i < components_.size();
i++)
556 delete components_[
i];
557 component_names_.clear();
564 std::vector<std::string> *modified_node_names)
const {
565 modified_node_names->resize(node_names_.size());
566 const std::string invalid_name =
"**";
567 size_t size = node_names_.size();
568 for (
size_t i = 0;
i < size;
i++) {
572 (*modified_node_names)[
i] = node_names_[
i];
574 (*modified_node_names)[
i] = invalid_name;
583 nodes_.swap(other->
nodes_);
589 if (first_char ==
'T') {
594 temp_trans_model.
Read(is, binary);
596 temp_am_nnet.
Read(is, binary);
602 std::ostringstream config_file_out;
603 std::string cur_line;
604 getline(is, cur_line);
605 if (!(cur_line ==
"" || cur_line ==
"\r"))
606 KALDI_ERR <<
"Expected newline in config file, got " << cur_line;
607 while (getline(is, cur_line)) {
609 if (cur_line ==
"" || cur_line ==
"\r")
611 config_file_out << cur_line << std::endl;
615 int32 num_components;
617 KALDI_ASSERT(num_components >= 0 && num_components < 100000);
618 components_.resize(num_components, NULL);
619 component_names_.resize(num_components);
620 for (
int32 c = 0; c < num_components; c++) {
622 ReadToken(is, binary, &(component_names_[c]));
626 std::istringstream config_file_in(config_file_out.str());
627 this->ReadConfig(config_file_in);
633 std::vector<std::string> config_lines;
634 const bool include_dim =
false;
635 GetConfigLines(include_dim, &config_lines);
636 for (
size_t i = 0;
i < config_lines.size();
i++) {
638 os << config_lines[
i] << std::endl;
643 int32 num_components = components_.size();
648 for (
int32 c = 0; c < num_components; c++) {
651 components_[c]->Write(os, binary);
660 for (
int32 n = 0;
n < NumNodes();
n++) {
670 int32 n = GetNodeIndex(input_name);
671 if (n == -1)
return -1;
678 int32 n = GetNodeIndex(input_name);
679 if (n == -1 || !IsOutputNode(n))
return -1;
681 return node.
Dim(*
this);
685 KALDI_ASSERT(static_cast<size_t>(node_index) < node_names_.size());
690 KALDI_ASSERT(static_cast<size_t>(component_index) < component_names_.size());
695 int32 num_nodes = nodes_.size(),
697 num_output_nodes = 0;
699 for (
int32 n = 0;
n < num_nodes;
n++) {
701 std::string node_name = node_names_[
n];
711 std::vector<int32> node_deps;
714 for (
size_t i = 0;
i < node_deps.size();
i++) {
715 int32 src_node = node_deps[
i];
717 NodeType src_type = nodes_[src_node].node_type;
720 KALDI_ERR <<
"Invalid source node type in Descriptor: source node " 721 << node_names_[src_node];
731 src_dim = src_node.
Dim(*
this);
733 KALDI_ERR <<
"Error in Descriptor for network-node " 734 << node_name <<
" (see error above)";
736 if (src_dim != input_dim) {
737 KALDI_ERR <<
"Dimension mismatch for network-node " 738 << node_name <<
": input-dim " 739 << src_dim <<
" versus component-input-dim " 746 KALDI_ASSERT(input_node >= 0 && input_node < num_nodes);
747 NodeType input_type = nodes_[input_node].node_type;
749 KALDI_ERR <<
"Invalid source node type in DimRange node: source node " 750 << node_names_[input_node];
751 int32 input_dim = nodes_[input_node].Dim(*
this);
754 KALDI_ERR <<
"Invalid node dimensions for DimRange node: " << node_name
755 <<
": input-dim=" << input_dim <<
", dim=" << node.
dim 761 KALDI_ERR <<
"Invalid node type for node " << node_name;
765 int32 num_components = components_.size();
766 for (
int32 c = 0; c < num_components; c++) {
767 const std::string &component_name = component_names_[c];
769 "Duplicate component names?");
775 if (warn_for_orphans) {
776 std::vector<int32> orphans;
778 for (
size_t i = 0;
i < orphans.size();
i++) {
779 KALDI_WARN <<
"Component " << GetComponentName(orphans[
i])
780 <<
" is never used by any node.";
783 for (
size_t i = 0;
i < orphans.size();
i++) {
784 if (!IsComponentInputNode(orphans[
i])) {
789 KALDI_WARN <<
"Node " << GetNodeName(orphans[i])
790 <<
" is never used to compute any output.";
798 component_names_(nnet.component_names_),
799 components_(nnet.components_.size()),
800 node_names_(nnet.node_names_),
801 nodes_(nnet.nodes_) {
822 std::ostringstream os;
825 int32 left_context, right_context;
827 os <<
"left-context: " << left_context <<
"\n";
828 os <<
"right-context: " << right_context <<
"\n";
831 os <<
"modulus: " << this->
Modulus() <<
"\n";
832 std::vector<std::string> config_lines;
833 bool include_dim =
true;
835 for (
size_t i = 0;
i < config_lines.size();
i++)
836 os << config_lines[
i] <<
"\n";
845 std::vector<int32> orphan_components;
847 KALDI_LOG <<
"Removing " << orphan_components.size()
848 <<
" orphan components.";
849 if (orphan_components.empty())
852 new_num_components = 0;
853 std::vector<int32> old2new_map(old_num_components, 0);
854 for (
size_t i = 0;
i < orphan_components.size();
i++)
855 old2new_map[orphan_components[
i]] = -1;
856 std::vector<Component*> new_components;
857 std::vector<std::string> new_component_names;
858 for (
int32 c = 0; c < old_num_components; c++) {
859 if (old2new_map[c] != -1) {
860 old2new_map[c] = new_num_components++;
871 new_c = old2new_map[old_c];
873 nodes_[
n].u.component_index = new_c;
882 if (nodes_to_remove.empty())
886 std::vector<int32> old2new_map(old_num_nodes, 0);
887 for (
size_t i = 0;
i < nodes_to_remove.size();
i++)
888 old2new_map[nodes_to_remove[
i]] = -1;
889 std::vector<NetworkNode> new_nodes;
890 std::vector<std::string> new_node_names;
891 for (
int32 n = 0;
n < old_num_nodes;
n++) {
892 if (old2new_map[
n] != -1) {
893 old2new_map[
n] = new_num_nodes++;
894 new_nodes.push_back(
nodes_[
n]);
898 for (
int32 n = 0;
n < new_num_nodes;
n++) {
903 std::ostringstream os;
905 std::vector<std::string> tokens;
908 tokens.push_back(
"end of input");
909 const std::string *token = &(tokens[0]);
912 if (!new_nodes[
n].descriptor.Parse(new_node_names, &token)) {
913 KALDI_ERR <<
"Code error removing orphan nodes.";
915 }
else if (new_nodes[
n].node_type ==
kDimRange) {
916 int32 old_node_index = new_nodes[
n].u.node_index,
917 new_node_index = old2new_map[old_node_index];
918 KALDI_ASSERT(new_node_index >= 0 && new_node_index <= new_num_nodes);
919 new_nodes[
n].u.node_index = new_node_index;
924 bool warn_for_orphans =
false;
928 Check(warn_for_orphans);
933 std::vector<int32> orphan_nodes;
935 if (!remove_orphan_inputs)
936 for (
int32 i = 0;
i < orphan_nodes.size();
i++)
938 orphan_nodes.erase(orphan_nodes.begin() +
i);
943 int32 num_nodes_removed = 0;
944 for (
int32 i = 0; i < orphan_nodes.size(); i++)
948 KALDI_LOG <<
"Removed " << num_nodes_removed <<
" orphan nodes.";
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
int32 InputDim(const std::string &input_name) const
const std::string & FirstToken() const
void GetNodeDependencies(std::vector< int32 > *node_indexes) const
const std::string WholeLine()
int32 AddComponent(const std::string &name, Component *component)
Adds a new component with the given name, which should not be the same as any existing component name...
void Write(std::ostream &ostream, bool binary) const
void ReadConfig(std::istream &config_file)
void FindOrphanComponents(const Nnet &nnet, std::vector< int32 > *components)
This function finds a list of components that are never used, and outputs the integer component index...
void GetSomeNodeNames(std::vector< std::string > *modified_node_names) const
const std::string & GetNodeName(int32 node_index) const
returns individual node name.
bool Parse(const std::vector< std::string > &node_names, const std::string **next_token)
Abstract base-class for neural-net components.
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
void FindOrphanNodes(const Nnet &nnet, std::vector< int32 > *nodes)
This function finds a list of nodes that are never used to compute any output, and outputs the intege...
bool DescriptorTokenize(const std::string &input, std::vector< std::string > *tokens)
This function tokenizes input when parsing Descriptor configuration values.
bool IsInputNode(int32 node) const
Returns true if this is an input node, meaning that it is of type kInput.
static void RemoveRedundantConfigLines(int32 num_lines_initial, std::vector< ConfigLine > *config_lines)
void ReadToken(std::istream &is, bool binary, std::string *str)
ReadToken gets the next token and puts it in str (exception on failure).
std::vector< std::string > component_names_
std::vector< Component * > components_
void SetComponent(int32 c, Component *component)
Replace the component indexed by c with a new component.
virtual int32 OutputDim() const =0
Returns output-dimension of this component.
bool IsComponentNode(int32 node) const
Returns true if this is a component node, meaning that it is of type kComponent.
void SortAndUniq(std::vector< T > *vec)
Sorts and uniq's (removes duplicates) from a vector.
void RemoveOrphanComponents()
ObjectiveType objective_type
bool IsValidName(const std::string &name)
Returns true if 'name' would be a valid name for a component or node in a nnet3::Nnet.
const Nnet & GetNnet() const
std::string GetAsConfigLine(int32 node_index, bool include_dim) const
void Read(std::istream &is, bool binary)
void SetNodeName(int32 node_index, const std::string &new_name)
This can be used to modify individual node names.
std::vector< std::string > node_names_
int32 OutputDim(const std::string &output_name) const
I Lcm(I m, I n)
Returns the least common multiple of two integers.
This file contains some miscellaneous functions dealing with class Nnet.
bool IsToken(const std::string &token)
Returns true if "token" is nonempty, and all characters are printable and whitespace-free.
This file contains declarations of components that are "simple", meaning they don't care about the in...
std::string Info() const
returns some human-readable information about the network, mostly for debugging purposes.
int32 Modulus() const
[Relevant for clockwork RNNs and similar].
std::vector< NetworkNode > nodes_
void ProcessOutputNodeConfigLine(int32 pass, ConfigLine *config)
std::string UnusedValues() const
returns e.g.
int32 NumParameters(const Nnet &src)
Returns the total of the number of parameters in the updatable components of the nnet.
void ParseConfigLines(const std::vector< std::string > &lines, std::vector< ConfigLine > *config_lines)
bool IsOutputNode(int32 node) const
Returns true if this is an output node, meaning that it is of type kDescriptor and is not directly fo...
int32 Dim(const Nnet &nnet) const
static void ExpectToken(const std::string &token, const std::string &what_we_are_parsing, const std::string **next_token)
void ComputeSimpleNnetContext(const Nnet &nnet, int32 *left_context, int32 *right_context)
ComputeSimpleNnetContext computes the left-context and right-context of a nnet.
void Read(std::istream &istream, bool binary)
void Read(std::istream &is, bool binary)
void RemoveOrphanNodes(bool remove_orphan_inputs=false)
int32 GetComponentIndex(const std::string &node_name) const
returns index associated with this component name, or -1 if no such index.
void GetConfigLines(bool include_dim, std::vector< std::string > *config_lines) const
void ProcessComponentConfigLine(int32 initial_num_components, ConfigLine *config)
static Component * ReadNew(std::istream &is, bool binary)
Read component from stream (works out its type). Dies on error.
const std::string & GetComponentName(int32 component_index) const
returns individual component name.
void ReadConfigLines(std::istream &is, std::vector< std::string > *lines)
This function reads in a config file and *appends* its contents to a vector of lines; it is responsib...
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
const std::vector< std::string > & GetComponentNames() const
returns vector of component names (needed by some parsing code, for instance).
int PeekToken(std::istream &is, bool binary)
PeekToken will return the first character of the next token, or -1 if end of file.
NetworkNode is used to represent three types of thing: either an input of the network (which pretty ...
Component * GetComponent(int32 c)
Return component indexed c. Not a copy; not owned by caller.
void RemoveSomeNodes(const std::vector< int32 > &nodes_to_remove)
void ProcessDimRangeNodeConfigLine(int32 pass, ConfigLine *config)
bool IsSimpleNnet(const Nnet &nnet)
This function returns true if the nnet has the following properties: It has an output called "output"...
int32 NumComponents() const
This class is responsible for parsing input like hi-there xx=yyy a=b c empty= f-oo=Append(bar, sss) ba_z=123 bing='a b c' baz="a b c d='a b' e" and giving you access to the fields, in this case.
#define KALDI_ASSERT(cond)
bool HasUnusedValues() const
bool GetValue(const std::string &key, std::string *value)
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
void Check(bool warn_for_orphans=true) const
Checks the neural network for validity (dimension matches and various other requirements).
void ProcessInputNodeConfigLine(ConfigLine *config)
union kaldi::nnet3::NetworkNode::@15 u
int32 GetNodeIndex(const std::string &node_name) const
returns index associated with this node name, or -1 if no such index.
virtual int32 InputDim() const =0
Returns input-dimension of this component.
static Component * NewComponentOfType(const std::string &type)
Returns a new Component of the given type e.g.
bool IsDescriptorNode(int32 node) const
Returns true if this is a descriptor node, meaning that it is of type kDescriptor.
Nnet & operator=(const Nnet &nnet)
bool IsDimRangeNode(int32 node) const
Returns true if this is a dim-range node, meaning that it is of type kDimRange.
NetworkNode(NodeType nt=kNone)
void ProcessComponentNodeConfigLine(int32 pass, ConfigLine *config)
const std::vector< std::string > & GetNodeNames() const
returns vector of node names (needed by some parsing code, for instance).
int32 Dim(const Nnet &nnet) const
virtual void InitFromConfig(ConfigLine *cfl)=0
Initialize, from a ConfigLine object.
bool IsComponentInputNode(int32 node) const
Returns true if this is a component-input node, i.e.