Nnet Class Reference

#include <nnet-nnet.h>

Collaboration diagram for Nnet:

Public Member Functions

void ReadConfig (std::istream &config_file)
 
int32 NumComponents () const
 
int32 NumNodes () const
 
Component * GetComponent (int32 c)
 Return component indexed c. Not a copy; not owned by caller. More...
 
const Component * GetComponent (int32 c) const
 Return component indexed c (const version). More...
 
void SetComponent (int32 c, Component *component)
 Replace the component indexed by c with a new component. More...
 
int32 AddComponent (const std::string &name, Component *component)
 Adds a new component with the given name, which should not be the same as any existing component name. More...
 
const NetworkNode & GetNode (int32 node) const
 returns const reference to a particular numbered network node. More...
 
NetworkNode & GetNode (int32 node)
 Non-const accessor for the node... use with extreme caution. More...
 
bool IsComponentNode (int32 node) const
 Returns true if this is a component node, meaning that it is of type kComponent. More...
 
bool IsDimRangeNode (int32 node) const
 Returns true if this is a dim-range node, meaning that it is of type kDimRange. More...
 
bool IsInputNode (int32 node) const
 Returns true if this is an input node, meaning that it is of type kInput. More...
 
bool IsDescriptorNode (int32 node) const
 Returns true if this is a descriptor node, meaning that it is of type kDescriptor. More...
 
bool IsOutputNode (int32 node) const
 Returns true if this is an output node, meaning that it is of type kDescriptor and is not directly followed by a node of type kComponent. More...
 
bool IsComponentInputNode (int32 node) const
 Returns true if this is a component-input node, i.e. a node of type kDescriptor that immediately precedes a node of type kComponent. More...
 
const std::vector< std::string > & GetNodeNames () const
 returns vector of node names (needed by some parsing code, for instance). More...
 
const std::string & GetNodeName (int32 node_index) const
 returns individual node name. More...
 
void SetNodeName (int32 node_index, const std::string &new_name)
 This can be used to modify individual node names. More...
 
const std::vector< std::string > & GetComponentNames () const
 returns vector of component names (needed by some parsing code, for instance). More...
 
const std::string & GetComponentName (int32 component_index) const
 returns individual component name. More...
 
int32 GetNodeIndex (const std::string &node_name) const
 returns index associated with this node name, or -1 if no such index. More...
 
int32 GetComponentIndex (const std::string &node_name) const
 returns index associated with this component name, or -1 if no such index. More...
 
int32 InputDim (const std::string &input_name) const
 
int32 OutputDim (const std::string &output_name) const
 
void Read (std::istream &istream, bool binary)
 
void Write (std::ostream &ostream, bool binary) const
 
void Check (bool warn_for_orphans=true) const
 Checks the neural network for validity (dimension matches and various other requirements). More...
 
std::string Info () const
 returns some human-readable information about the network, mostly for debugging purposes. More...
 
int32 Modulus () const
 [Relevant for clockwork RNNs and similar]. More...
 
 ~Nnet ()
 
 Nnet ()
 
 Nnet (const Nnet &nnet)
 
Nnet * Copy () const
 
void Swap (Nnet *other)
 
Nnet & operator= (const Nnet &nnet)
 
void RemoveOrphanNodes (bool remove_orphan_inputs=false)
 
void RemoveOrphanComponents ()
 
void RemoveSomeNodes (const std::vector< int32 > &nodes_to_remove)
 
void ResetGenerators ()
 
void GetConfigLines (bool include_dim, std::vector< std::string > *config_lines) const
 

Private Member Functions

void Destroy ()
 
std::string GetAsConfigLine (int32 node_index, bool include_dim) const
 
void ProcessComponentConfigLine (int32 initial_num_components, ConfigLine *config)
 
void ProcessComponentNodeConfigLine (int32 pass, ConfigLine *config)
 
void ProcessInputNodeConfigLine (ConfigLine *config)
 
void ProcessOutputNodeConfigLine (int32 pass, ConfigLine *config)
 
void ProcessDimRangeNodeConfigLine (int32 pass, ConfigLine *config)
 
void GetSomeNodeNames (std::vector< std::string > *modified_node_names) const
 

Static Private Member Functions

static void RemoveRedundantConfigLines (int32 num_lines_initial, std::vector< ConfigLine > *config_lines)
 

Private Attributes

std::vector< std::string > component_names_
 
std::vector< Component * > components_
 
std::vector< std::string > node_names_
 
std::vector< NetworkNode > nodes_
 

Detailed Description

Definition at line 115 of file nnet-nnet.h.

Constructor & Destructor Documentation

◆ ~Nnet()

~Nnet ( )
inline

Definition at line 237 of file nnet-nnet.h.

237 { Destroy(); }

◆ Nnet() [1/2]

Nnet ( )
inline

Definition at line 240 of file nnet-nnet.h.

240 { }

◆ Nnet() [2/2]

Nnet ( const Nnet & nnet)

Definition at line 797 of file nnet-nnet.cc.

References Nnet::Check(), Nnet::components_, and rnnlm::i.

797  :
798  component_names_(nnet.component_names_),
799  components_(nnet.components_.size()),
800  node_names_(nnet.node_names_),
801  nodes_(nnet.nodes_) {
802  for (size_t i = 0; i < components_.size(); i++)
803  components_[i] = nnet.components_[i]->Copy();
804  Check();
805 }
std::vector< std::string > component_names_
Definition: nnet-nnet.h:326
std::vector< Component * > components_
Definition: nnet-nnet.h:331
std::vector< std::string > node_names_
Definition: nnet-nnet.h:337
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340
void Check(bool warn_for_orphans=true) const
Checks the neural network for validity (dimension matches and various other requirements).
Definition: nnet-nnet.cc:694

Member Function Documentation

◆ AddComponent()

int32 AddComponent ( const std::string &  name,
Component *  component 
)

Adds a new component with the given name, which should not be the same as any existing component name.

Returns the new component index. Takes ownership of the pointer 'component'.

Definition at line 161 of file nnet-nnet.cc.

References kaldi::IsValidName(), and KALDI_ASSERT.

Referenced by ModelCollapser::CollapseComponentsAffine(), ModelCollapser::CollapseComponentsScale(), SvdApplier::DecomposeComponents(), ModelCollapser::GetDiagonallyPreModifiedComponentIndex(), and ModelCollapser::GetScaledComponentIndex().

162  {
163  int32 ans = components_.size();
164  KALDI_ASSERT(IsValidName(name) && component != NULL);
165  components_.push_back(component);
166  component_names_.push_back(name);
167  return ans;
168 }
kaldi::int32 int32
std::vector< std::string > component_names_
Definition: nnet-nnet.h:326
std::vector< Component * > components_
Definition: nnet-nnet.h:331
bool IsValidName(const std::string &name)
Returns true if 'name' would be a valid name for a component or node in an nnet3 Nnet.
Definition: text-utils.cc:553
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ Check()

void Check ( bool  warn_for_orphans = true) const

Checks the neural network for validity (dimension matches and various other requirements).

You can call this with warn_for_orphans = false to disable the warnings that are printed if orphan nodes or components exist.

Definition at line 694 of file nnet-nnet.cc.

References NetworkNode::component_index, NetworkNode::descriptor, NetworkNode::dim, NetworkNode::Dim(), NetworkNode::dim_offset, kaldi::nnet3::FindOrphanComponents(), kaldi::nnet3::FindOrphanNodes(), Descriptor::GetNodeDependencies(), rnnlm::i, Component::InputDim(), KALDI_ASSERT, KALDI_ERR, KALDI_WARN, kaldi::nnet3::kComponent, kaldi::nnet3::kDescriptor, kaldi::nnet3::kDimRange, kaldi::nnet3::kInput, rnnlm::n, NetworkNode::node_index, NetworkNode::node_type, kaldi::SortAndUniq(), and NetworkNode::u.

Referenced by Nnet::Nnet(), Nnet::operator=(), Nnet::RemoveOrphanComponents(), and Nnet::RemoveSomeNodes().

694  {
695  int32 num_nodes = nodes_.size(),
696  num_input_nodes = 0,
697  num_output_nodes = 0;
698  KALDI_ASSERT(num_nodes != 0);
699  for (int32 n = 0; n < num_nodes; n++) {
700  const NetworkNode &node = nodes_[n];
701  std::string node_name = node_names_[n];
702  KALDI_ASSERT(GetNodeIndex(node_name) == n);
703  switch (node.node_type) {
704  case kInput:
705  KALDI_ASSERT(node.dim > 0);
706  num_input_nodes++;
707  break;
708  case kDescriptor: {
709  if (IsOutputNode(n))
710  num_output_nodes++;
711  std::vector<int32> node_deps;
712  node.descriptor.GetNodeDependencies(&node_deps);
713  SortAndUniq(&node_deps);
714  for (size_t i = 0; i < node_deps.size(); i++) {
715  int32 src_node = node_deps[i];
716  KALDI_ASSERT(src_node >= 0 && src_node < num_nodes);
717  NodeType src_type = nodes_[src_node].node_type;
718  if (src_type != kInput && src_type != kDimRange &&
719  src_type != kComponent)
720  KALDI_ERR << "Invalid source node type in Descriptor: source node "
721  << node_names_[src_node];
722  }
723  break;
724  }
725  case kComponent: {
726  KALDI_ASSERT(n > 0 && nodes_[n-1].node_type == kDescriptor);
727  const NetworkNode &src_node = nodes_[n-1];
728  const Component *c = GetComponent(node.u.component_index);
729  int32 src_dim, input_dim = c->InputDim();
730  try {
731  src_dim = src_node.Dim(*this);
732  } catch (...) {
733  KALDI_ERR << "Error in Descriptor for network-node "
734  << node_name << " (see error above)";
735  }
736  if (src_dim != input_dim) {
737  KALDI_ERR << "Dimension mismatch for network-node "
738  << node_name << ": input-dim "
739  << src_dim << " versus component-input-dim "
740  << input_dim;
741  }
742  break;
743  }
744  case kDimRange: {
745  int32 input_node = node.u.node_index;
746  KALDI_ASSERT(input_node >= 0 && input_node < num_nodes);
747  NodeType input_type = nodes_[input_node].node_type;
748  if (input_type != kInput && input_type != kComponent)
749  KALDI_ERR << "Invalid source node type in DimRange node: source node "
750  << node_names_[input_node];
751  int32 input_dim = nodes_[input_node].Dim(*this);
752  if (!(node.dim > 0 && node.dim_offset >= 0 &&
753  node.dim + node.dim_offset <= input_dim)) {
754  KALDI_ERR << "Invalid node dimensions for DimRange node: " << node_name
755  << ": input-dim=" << input_dim << ", dim=" << node.dim
756  << ", dim-offset=" << node.dim_offset;
757  }
758  break;
759  }
760  default:
761  KALDI_ERR << "Invalid node type for node " << node_name;
762  }
763  }
764 
765  int32 num_components = components_.size();
766  for (int32 c = 0; c < num_components; c++) {
767  const std::string &component_name = component_names_[c];
768  KALDI_ASSERT(GetComponentIndex(component_name) == c &&
769  "Duplicate component names?");
770  }
771  KALDI_ASSERT(num_input_nodes > 0);
772  KALDI_ASSERT(num_output_nodes > 0);
773 
774 
775  if (warn_for_orphans) {
776  std::vector<int32> orphans;
777  FindOrphanComponents(*this, &orphans);
778  for (size_t i = 0; i < orphans.size(); i++) {
779  KALDI_WARN << "Component " << GetComponentName(orphans[i])
780  << " is never used by any node.";
781  }
782  FindOrphanNodes(*this, &orphans);
783  for (size_t i = 0; i < orphans.size(); i++) {
784  if (!IsComponentInputNode(orphans[i])) {
785  // There is no point warning about component-input nodes, since the
786  // warning will be printed for the corresponding component nodes.. a
787  // duplicate warning might be confusing to the user, as the
788  // component-input nodes are implicit and usually hidden from users.
789  KALDI_WARN << "Node " << GetNodeName(orphans[i])
790  << " is never used to compute any output.";
791  }
792  }
793  }
794 }
void FindOrphanComponents(const Nnet &nnet, std::vector< int32 > *components)
This function finds a list of components that are never used, and outputs their integer component indexes to 'components'.
Definition: nnet-utils.cc:591
const std::string & GetNodeName(int32 node_index) const
returns individual node name.
Definition: nnet-nnet.cc:684
void FindOrphanNodes(const Nnet &nnet, std::vector< int32 > *nodes)
This function finds a list of nodes that are never used to compute any output, and outputs their integer node indexes to 'nodes'.
Definition: nnet-utils.cc:607
kaldi::int32 int32
std::vector< std::string > component_names_
Definition: nnet-nnet.h:326
std::vector< Component * > components_
Definition: nnet-nnet.h:331
void SortAndUniq(std::vector< T > *vec)
Sorts and uniq's a vector, i.e. sorts it and removes duplicates.
Definition: stl-utils.h:39
std::vector< std::string > node_names_
Definition: nnet-nnet.h:337
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340
bool IsOutputNode(int32 node) const
Returns true if this is an output node, meaning that it is of type kDescriptor and is not directly followed by a node of type kComponent.
Definition: nnet-nnet.cc:112
struct rnnlm::@11::@12 n
int32 GetComponentIndex(const std::string &node_name) const
returns index associated with this component name, or -1 if no such index.
Definition: nnet-nnet.cc:474
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
const std::string & GetComponentName(int32 component_index) const
returns individual component name.
Definition: nnet-nnet.cc:689
Component * GetComponent(int32 c)
Return component indexed c. Not a copy; not owned by caller.
Definition: nnet-nnet.cc:150
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
int32 GetNodeIndex(const std::string &node_name) const
returns index associated with this node name, or -1 if no such index.
Definition: nnet-nnet.cc:466
bool IsComponentInputNode(int32 node) const
Returns true if this is a component-input node, i.e. a node of type kDescriptor that immediately precedes a node of type kComponent.
Definition: nnet-nnet.cc:172

◆ Copy()

Nnet* Copy ( ) const
inline

◆ Destroy()

void Destroy ( )
private

Definition at line 554 of file nnet-nnet.cc.

References rnnlm::i.

Referenced by Nnet::operator=().

554  {
555  for (size_t i = 0; i < components_.size(); i++)
556  delete components_[i];
557  component_names_.clear();
558  components_.clear();
559  node_names_.clear();
560  nodes_.clear();
561 }
std::vector< std::string > component_names_
Definition: nnet-nnet.h:326
std::vector< Component * > components_
Definition: nnet-nnet.h:331
std::vector< std::string > node_names_
Definition: nnet-nnet.h:337
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340

◆ GetAsConfigLine()

std::string GetAsConfigLine ( int32  node_index,
bool  include_dim 
) const
private

Definition at line 71 of file nnet-nnet.cc.

References KALDI_ASSERT, KALDI_ERR, kaldi::nnet3::kComponent, kaldi::nnet3::kDescriptor, kaldi::nnet3::kDimRange, kaldi::nnet3::kInput, kaldi::nnet3::kLinear, NetworkNode::node_index, and NetworkNode::node_type.

71  {
72  std::ostringstream ans;
73  KALDI_ASSERT(node_index < nodes_.size() &&
74  nodes_.size() == node_names_.size());
75  const NetworkNode &node = nodes_[node_index];
76  const std::string &name = node_names_[node_index];
77  switch (node.node_type) {
78  case kInput:
79  ans << "input-node name=" << name << " dim=" << node.dim;
80  break;
81  case kDescriptor:
82  // assert that it's an output-descriptor, not one describing the input to
83  // a component-node.
84  KALDI_ASSERT(IsOutputNode(node_index));
85  ans << "output-node name=" << name << " input=";
86  node.descriptor.WriteConfig(ans, node_names_);
87  if (include_dim)
88  ans << " dim=" << node.Dim(*this);
89  ans << " objective=" << (node.u.objective_type == kLinear ? "linear" :
90  "quadratic");
91  break;
92  case kComponent:
93  ans << "component-node name=" << name << " component="
94  << component_names_[node.u.component_index] << " input=";
95  KALDI_ASSERT(nodes_[node_index-1].node_type == kDescriptor);
96  nodes_[node_index-1].descriptor.WriteConfig(ans, node_names_);
97  if (include_dim)
98  ans << " input-dim=" << nodes_[node_index-1].Dim(*this)
99  << " output-dim=" << node.Dim(*this);
100  break;
101  case kDimRange:
102  ans << "dim-range-node name=" << name << " input-node="
103  << node_names_[node.u.node_index] << " dim-offset="
104  << node.dim_offset << " dim=" << node.dim;
105  break;
106  default:
107  KALDI_ERR << "Unknown node type.";
108  }
109  return ans.str();
110 }
std::vector< std::string > component_names_
Definition: nnet-nnet.h:326
std::vector< std::string > node_names_
Definition: nnet-nnet.h:337
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340
bool IsOutputNode(int32 node) const
Returns true if this is an output node, meaning that it is of type kDescriptor and is not directly followed by a node of type kComponent.
Definition: nnet-nnet.cc:112
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ GetComponent() [1/2]

Component * GetComponent ( int32  c)

Return component indexed c. Not a copy; not owned by caller.

Definition at line 150 of file nnet-nnet.cc.

References KALDI_ASSERT.

Referenced by Compiler::AddBackwardStepComponent(), ComputationGraphBuilder::AddDependencies(), Compiler::AddForwardStepComponent(), kaldi::nnet3::AddNnet(), kaldi::nnet3::AddNnetComponents(), kaldi::nnet3::ApplyL2Regularization(), ComputationChecker::CheckComputationCompression(), ComputationChecker::CheckComputationIndexes(), ModelCollapser::CollapseComponentsAffine(), ModelCollapser::CollapseComponentsBatchnorm(), ModelCollapser::CollapseComponentsDropout(), ModelCollapser::CollapseComponentsScale(), kaldi::nnet3::ComponentDotProducts(), kaldi::nnet3::ComputeCommandAttributes(), ComputationGraphBuilder::ComputeComputableInfo(), kaldi::nnet3::ComputeComputationGraph(), Compiler::ComputeDerivNeeded(), ComputationExpander::ComputePrecomputedIndexes(), kaldi::nnet3::ConsolidateMemory(), ModelUpdateConsolidator::ConsolidateModelUpdate(), kaldi::nnet3::ConstrainOrthonormal(), kaldi::nnet3::ConvertRepeatedToBlockAffine(), NnetComputer::DebugAfterExecute(), NnetComputer::DebugBeforeExecute(), SvdApplier::DecomposeComponents(), NetworkNode::Dim(), kaldi::nnet3::DotProduct(), NnetComputer::ExecuteCommand(), kaldi::nnet3::FreezeNaturalGradient(), ModelCollapser::GetDiagonallyPreModifiedComponentIndex(), ModelCollapser::GetScaledComponentIndex(), Compiler::GetStrideType(), kaldi::nnet3::HasBatchnorm(), VariableMergingOptimizer::MergeVariables(), DerivativeTimeLimiter::ModifyCommand(), kaldi::nnet3::NnetParametersAreIdentical(), kaldi::nnet3::NumParameters(), kaldi::nnet3::NumUpdatableComponents(), kaldi::nnet3::PerturbParams(), MaxChangeStats::Print(), kaldi::nnet3::PrintVectorPerUpdatableComponent(), ComputationStepsComputer::ProcessComponentStep(), MemoryCompressionOptimizer::ProcessMatrix(), kaldi::nnet3::ReadEditConfig(), kaldi::nnet3::ReduceRankOfComponents(), Nnet::ResetGenerators(), kaldi::nnet3::ResetGenerators(), kaldi::nnet3::ScaleBatchnormStats(), kaldi::nnet3::ScaleNnet(), kaldi::nnet3::SetBatchnormTestMode(), kaldi::nnet3::SetDropoutProportion(), 
kaldi::nnet3::SetDropoutTestMode(), kaldi::nnet3::SetLearningRate(), kaldi::nnet3::SetNnetAsGradient(), kaldi::nnet3::SetRequireDirectInput(), Compiler::SetUpPrecomputedIndexes(), kaldi::nnet3::UnitTestConvertRepeatedToBlockAffine(), kaldi::nnet3::UnitTestConvertRepeatedToBlockAffineComposite(), kaldi::nnet3::UnVectorizeNnet(), kaldi::nnet3::UpdateNnetWithMaxChange(), kaldi::nnet3::VectorizeNnet(), and kaldi::nnet3::ZeroComponentStats().

150  {
151  KALDI_ASSERT(static_cast<size_t>(c) < components_.size());
152  return components_[c];
153 }
std::vector< Component * > components_
Definition: nnet-nnet.h:331
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ GetComponent() [2/2]

const Component * GetComponent ( int32  c) const

Return component indexed c (const version).

Not a copy; not owned by caller.

Definition at line 145 of file nnet-nnet.cc.

References KALDI_ASSERT.

145  {
146  KALDI_ASSERT(static_cast<size_t>(c) < components_.size());
147  return components_[c];
148 }
std::vector< Component * > components_
Definition: nnet-nnet.h:331
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ GetComponentIndex()

int32 GetComponentIndex ( const std::string &  node_name) const

returns index associated with this component name, or -1 if no such index.

Definition at line 474 of file nnet-nnet.cc.

References rnnlm::i.

Referenced by ModelCollapser::CollapseComponentsAffine(), ModelCollapser::CollapseComponentsScale(), SvdApplier::DecomposeComponents(), ModelCollapser::GetDiagonallyPreModifiedComponentIndex(), and ModelCollapser::GetScaledComponentIndex().

474  {
475  size_t size = component_names_.size();
476  for (size_t i = 0; i < size; i++)
477  if (component_names_[i] == component_name)
478  return static_cast<int32>(i);
479  return -1;
480 }
kaldi::int32 int32
std::vector< std::string > component_names_
Definition: nnet-nnet.h:326

◆ GetComponentName()

◆ GetComponentNames()

const std::vector< std::string > & GetComponentNames ( ) const

returns vector of component names (needed by some parsing code, for instance).

Definition at line 67 of file nnet-nnet.cc.

67  {
68  return component_names_;
69 }
std::vector< std::string > component_names_
Definition: nnet-nnet.h:326

◆ GetConfigLines()

void GetConfigLines ( bool  include_dim,
std::vector< std::string > *  config_lines 
) const

Definition at line 180 of file nnet-nnet.cc.

References rnnlm::n.

Referenced by Nnet::Info(), and kaldi::nnet3::ModifyNnetIvectorPeriod().

181  {
182  config_lines->clear();
183  for (int32 n = 0; n < NumNodes(); n++)
184  if (!IsComponentInputNode(n))
185  config_lines->push_back(GetAsConfigLine(n, include_dim));
186 
187 }
int32 NumNodes() const
Definition: nnet-nnet.h:126
kaldi::int32 int32
std::string GetAsConfigLine(int32 node_index, bool include_dim) const
Definition: nnet-nnet.cc:71
struct rnnlm::@11::@12 n
bool IsComponentInputNode(int32 node) const
Returns true if this is a component-input node, i.e. a node of type kDescriptor that immediately precedes a node of type kComponent.
Definition: nnet-nnet.cc:172

◆ GetNode() [1/2]

◆ GetNode() [2/2]

NetworkNode& GetNode ( int32  node)
inline

Non-const accessor for the node... use with extreme caution.

Definition at line 152 of file nnet-nnet.h.

References NetworkNode::component_index, KALDI_ASSERT, and NetworkNode::node_index.

152  {
153  KALDI_ASSERT(node >= 0 && node < nodes_.size());
154  return nodes_[node];
155  }
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ GetNodeIndex()

◆ GetNodeName()

const std::string & GetNodeName ( int32  node_index) const

returns individual node name.

Definition at line 684 of file nnet-nnet.cc.

References KALDI_ASSERT, and NetworkNode::node_index.

Referenced by NnetComputer::CheckNoPendingIo(), kaldi::nnet3::PrintCommand(), ComputationStepsComputer::ProcessInputOrOutputStep(), and kaldi::nnet3::ReadEditConfig().

684  {
685  KALDI_ASSERT(static_cast<size_t>(node_index) < node_names_.size());
686  return node_names_[node_index];
687 }
std::vector< std::string > node_names_
Definition: nnet-nnet.h:337
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ GetNodeNames()

const std::vector< std::string > & GetNodeNames ( ) const

returns vector of node names (needed by some parsing code, for instance).

Definition at line 63 of file nnet-nnet.cc.

Referenced by Compiler::ComputeDerivNeeded(), Compiler::DeallocateMatrices(), kaldi::nnet3::EvaluateComputationRequest(), kaldi::nnet3::HasXentOutputs(), SvdApplier::ModifyTopology(), kaldi::nnet3::PrintComputationPreamble(), and ModelCollapser::ReplaceNodeInDescriptor().

63  {
64  return node_names_;
65 }
std::vector< std::string > node_names_
Definition: nnet-nnet.h:337

◆ GetSomeNodeNames()

void GetSomeNodeNames ( std::vector< std::string > *  modified_node_names) const
private

Definition at line 563 of file nnet-nnet.cc.

References rnnlm::i, kaldi::nnet3::kComponent, kaldi::nnet3::kDimRange, kaldi::nnet3::kInput, and NetworkNode::node_type.

564  {
565  modified_node_names->resize(node_names_.size());
566  const std::string invalid_name = "**";
567  size_t size = node_names_.size();
568  for (size_t i = 0; i < size; i++) {
569  if (nodes_[i].node_type == kComponent ||
570  nodes_[i].node_type == kInput ||
571  nodes_[i].node_type == kDimRange) {
572  (*modified_node_names)[i] = node_names_[i];
573  } else {
574  (*modified_node_names)[i] = invalid_name;
575  }
576  }
577 }
std::vector< std::string > node_names_
Definition: nnet-nnet.h:337
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340

◆ Info()

std::string Info ( ) const

returns some human-readable information about the network, mostly for debugging purposes.

Also see function NnetInfo() in nnet-utils.h, which prints out more extensive information.

Definition at line 821 of file nnet-nnet.cc.

References Nnet::component_names_, Nnet::components_, kaldi::nnet3::ComputeSimpleNnetContext(), Nnet::GetConfigLines(), rnnlm::i, kaldi::nnet3::IsSimpleNnet(), Nnet::Modulus(), and kaldi::nnet3::NumParameters().

Referenced by AmNnetSimple::Info(), main(), kaldi::nnet3::NnetInfo(), kaldi::nnet3::TestNnetDecodable(), and kaldi::nnet3::UnitTestNnetCompileLooped().

821  {
822  std::ostringstream os;
823 
824  if(IsSimpleNnet(*this)) {
825  int32 left_context, right_context;
826  ComputeSimpleNnetContext(*this, &left_context, &right_context);
827  os << "left-context: " << left_context << "\n";
828  os << "right-context: " << right_context << "\n";
829  }
830  os << "num-parameters: " << NumParameters(*this) << "\n";
831  os << "modulus: " << this->Modulus() << "\n";
832  std::vector<std::string> config_lines;
833  bool include_dim = true;
834  GetConfigLines(include_dim, &config_lines);
835  for (size_t i = 0; i < config_lines.size(); i++)
836  os << config_lines[i] << "\n";
837  // Get component info.
838  for (size_t i = 0; i < components_.size(); i++)
839  os << "component name=" << component_names_[i]
840  << " type=" << components_[i]->Info() << "\n";
841  return os.str();
842 }
kaldi::int32 int32
std::vector< std::string > component_names_
Definition: nnet-nnet.h:326
std::vector< Component * > components_
Definition: nnet-nnet.h:331
std::string Info() const
returns some human-readable information about the network, mostly for debugging purposes.
Definition: nnet-nnet.cc:821
int32 Modulus() const
[Relevant for clockwork RNNs and similar].
Definition: nnet-nnet.cc:658
int32 NumParameters(const Nnet &src)
Returns the total of the number of parameters in the updatable components of the nnet.
Definition: nnet-utils.cc:359
void ComputeSimpleNnetContext(const Nnet &nnet, int32 *left_context, int32 *right_context)
ComputeSimpleNnetContext computes the left-context and right-context of a nnet.
Definition: nnet-utils.cc:146
void GetConfigLines(bool include_dim, std::vector< std::string > *config_lines) const
Definition: nnet-nnet.cc:180
bool IsSimpleNnet(const Nnet &nnet)
This function returns true if the nnet has the following properties: It has an output called "output"...
Definition: nnet-utils.cc:52

◆ InputDim()

int32 InputDim ( const std::string &  input_name) const

Definition at line 669 of file nnet-nnet.cc.

References NetworkNode::dim, kaldi::nnet3::kInput, rnnlm::n, and NetworkNode::node_type.

Referenced by BatchedXvectorComputer::BatchedXvectorComputer(), kaldi::nnet3::ComputeExampleComputationRequestSimple(), kaldi::nnet3::CreateLoopedComputationRequest(), DecodableNnetLoopedOnlineBase::DecodableNnetLoopedOnlineBase(), AmNnetSimple::Info(), DecodableNnetSimpleLoopedInfo::Init(), AmNnetSimple::InputDim(), AmNnetSimple::IvectorDim(), NnetBatchComputer::NnetBatchComputer(), kaldi::nnet3::NnetInfo(), and kaldi::nnet3::TestNnetDecodable().

669  {
670  int32 n = GetNodeIndex(input_name);
671  if (n == -1) return -1;
672  const NetworkNode &node = nodes_[n];
673  if (node.node_type != kInput) return -1;
674  return node.dim;
675 }
kaldi::int32 int32
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340
struct rnnlm::@11::@12 n
int32 GetNodeIndex(const std::string &node_name) const
returns index associated with this node name, or -1 if no such index.
Definition: nnet-nnet.cc:466

◆ IsComponentInputNode()

bool IsComponentInputNode ( int32  node) const

Returns true if this is a component-input node, i.e. a node of type kDescriptor that immediately precedes a node of type kComponent.

Definition at line 172 of file nnet-nnet.cc.

References KALDI_ASSERT, kaldi::nnet3::kComponent, kaldi::nnet3::kDescriptor, and NetworkNode::node_type.

Referenced by Compiler::GetStrideType(), SvdApplier::ModifyTopology(), ComputationStepsComputer::ProcessSubPhase(), and Nnet::RemoveOrphanNodes().

172  {
173  int32 size = nodes_.size();
174  KALDI_ASSERT(node >= 0 && node < size);
175  return (node + 1 < size &&
176  nodes_[node].node_type == kDescriptor &&
177  nodes_[node+1].node_type == kComponent);
178 }
kaldi::int32 int32
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ IsComponentNode()

bool IsComponentNode ( int32  node) const

Returns true if this is a component node, meaning that it is of type kComponent.

Definition at line 132 of file nnet-nnet.cc.

References KALDI_ASSERT, kaldi::nnet3::kComponent, and NetworkNode::node_type.

Referenced by Compiler::ComputeDerivNeeded(), Compiler::ComputeStepDependencies(), kaldi::nnet3::FindOrphanComponents(), Compiler::GetStrideType(), SvdApplier::ModifyTopology(), ComputationStepsComputer::ProcessComponentStep(), ComputationStepsComputer::ProcessSubPhase(), and Nnet::RemoveOrphanComponents().

132  {
133  int32 size = nodes_.size();
134  KALDI_ASSERT(node >= 0 && node < size);
135  return (nodes_[node].node_type == kComponent);
136 }
kaldi::int32 int32
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ IsDescriptorNode()

bool IsDescriptorNode ( int32  node) const

Returns true if this is a descriptor node, meaning that it is of type kDescriptor.

Exactly one of IsOutputNode() or IsComponentInputNode() will also return true for such a node.

Definition at line 126 of file nnet-nnet.cc.

References KALDI_ASSERT, kaldi::nnet3::kDescriptor, and NetworkNode::node_type.

126  {
127  int32 size = nodes_.size();
128  KALDI_ASSERT(node >= 0 && node < size);
129  return (nodes_[node].node_type == kDescriptor);
130 }
kaldi::int32 int32
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ IsDimRangeNode()

bool IsDimRangeNode ( int32  node) const

Returns true if this is a dim-range node, meaning that it is of type kDimRange.

Definition at line 138 of file nnet-nnet.cc.

References KALDI_ASSERT, kaldi::nnet3::kDimRange, and NetworkNode::node_type.

Referenced by ComputationStepsComputer::ProcessDimRangeSubPhase(), and ComputationStepsComputer::ProcessSubPhase().

138  {
139  int32 size = nodes_.size();
140  KALDI_ASSERT(node >= 0 && node < size);
141  return (nodes_[node].node_type == kDimRange);
142 }
kaldi::int32 int32
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ IsInputNode()

◆ IsOutputNode()

bool IsOutputNode ( int32  node) const

Returns true if this is an output node, meaning that it is of type kDescriptor and is not directly followed by a node of type kComponent.

Definition at line 112 of file nnet-nnet.cc.

References KALDI_ASSERT, kaldi::nnet3::kComponent, kaldi::nnet3::kDescriptor, and NetworkNode::node_type.

Referenced by Compiler::AllocateMatrices(), ComputationGraphBuilder::Check(), ComputationChecker::CheckComputationIndexes(), Compiler::CompileBackwardDescriptor(), Compiler::CompileForwardDescriptor(), Compiler::ComputeDerivNeeded(), ComputationGraphBuilder::ComputeRequiredArray(), Compiler::DeallocateMatrices(), kaldi::nnet3::FindOrphanNodes(), kaldi::nnet3::GetChainComputationRequest(), kaldi::nnet3::GetComputationRequest(), kaldi::nnet3::GetDiscriminativeComputationRequest(), kaldi::nnet3::HasXentOutputs(), kaldi::nnet3::IsSimpleNnet(), SvdApplier::ModifyTopology(), kaldi::nnet3::NumOutputNodes(), ComputationStepsComputer::ProcessInputOrOutputStep(), NnetDiscriminativeComputeObjf::ProcessOutputs(), NnetChainTrainer::ProcessOutputs(), NnetChainComputeProb::ProcessOutputs(), NnetDiscriminativeTrainer::ProcessOutputs(), NnetComputeProb::ProcessOutputs(), NnetTrainer::ProcessOutputs(), ComputationStepsComputer::ProcessSubPhase(), and kaldi::nnet3::ReadEditConfig().

112  {
113  int32 size = nodes_.size();
114  KALDI_ASSERT(node >= 0 && node < size);
115  return (nodes_[node].node_type == kDescriptor &&
116  (node + 1 == size ||
117  nodes_[node + 1].node_type != kComponent));
118 }
kaldi::int32 int32
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ Modulus()

int32 Modulus ( ) const

[Relevant for clockwork RNNs and similar].

Computes the smallest integer n >=1 such that the neural net's behavior will be the same if we shift the input and output's time indexes (t) by integer multiples of n. Does this by computing the lcm of all the moduli of the Descriptors in the network.

Definition at line 658 of file nnet-nnet.cc.

References NetworkNode::descriptor, kaldi::nnet3::kDescriptor, kaldi::Lcm(), Descriptor::Modulus(), rnnlm::n, and NetworkNode::node_type.

Referenced by kaldi::nnet3::ComputeSimpleNnetContext(), kaldi::nnet3::ComputeSimpleNnetContextForShift(), kaldi::nnet3::CreateLoopedComputationRequest(), kaldi::nnet3::GetChunkSize(), Nnet::Info(), and NnetBatchComputer::NnetBatchComputer().

658  {
659  int32 ans = 1;
660  for (int32 n = 0; n < NumNodes(); n++) {
661  const NetworkNode &node = nodes_[n];
662  if (node.node_type == kDescriptor)
663  ans = Lcm(ans, node.descriptor.Modulus());
664  }
665  return ans;
666 }
int32 NumNodes() const
Definition: nnet-nnet.h:126
kaldi::int32 int32
I Lcm(I m, I n)
Returns the least common multiple of two integers.
Definition: kaldi-math.h:318
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340
struct rnnlm::@11::@12 n

◆ NumComponents()

int32 NumComponents ( ) const
inline

Definition at line 124 of file nnet-nnet.h.

Referenced by kaldi::nnet3::AddNnet(), kaldi::nnet3::AddNnetComponents(), kaldi::nnet3::ApplyL2Regularization(), ComputationChecker::CheckComputationIndexes(), ModelCollapser::Collapse(), kaldi::nnet3::ComponentDotProducts(), kaldi::nnet3::ConsolidateMemory(), ModelUpdateConsolidator::ConsolidateModelUpdate(), kaldi::nnet3::ConstrainOrthonormal(), kaldi::nnet3::ConvertRepeatedToBlockAffine(), SvdApplier::DecomposeComponents(), kaldi::nnet3::DotProduct(), kaldi::nnet3::FindOrphanComponents(), kaldi::nnet3::FreezeNaturalGradient(), kaldi::nnet3::HasBatchnorm(), kaldi::nnet3::NnetParametersAreIdentical(), kaldi::nnet3::NumParameters(), kaldi::nnet3::NumUpdatableComponents(), kaldi::nnet3::PerturbParams(), MaxChangeStats::Print(), kaldi::nnet3::PrintVectorPerUpdatableComponent(), kaldi::nnet3::ReadEditConfig(), kaldi::nnet3::ReduceRankOfComponents(), Nnet::ResetGenerators(), kaldi::nnet3::ResetGenerators(), kaldi::nnet3::ScaleBatchnormStats(), kaldi::nnet3::ScaleNnet(), kaldi::nnet3::SetBatchnormTestMode(), kaldi::nnet3::SetDropoutProportion(), kaldi::nnet3::SetDropoutTestMode(), kaldi::nnet3::SetLearningRate(), kaldi::nnet3::SetNnetAsGradient(), kaldi::nnet3::SetRequireDirectInput(), kaldi::nnet3::UnitTestConvertRepeatedToBlockAffine(), kaldi::nnet3::UnitTestConvertRepeatedToBlockAffineComposite(), kaldi::nnet3::UnVectorizeNnet(), kaldi::nnet3::UpdateNnetWithMaxChange(), kaldi::nnet3::VectorizeNnet(), and kaldi::nnet3::ZeroComponentStats().

124 { return components_.size(); }
std::vector< Component * > components_
Definition: nnet-nnet.h:331

◆ NumNodes()

int32 NumNodes ( ) const

Definition at line 126 of file nnet-nnet.h.

◆ operator=()

Nnet & operator= ( const Nnet &  nnet)

Definition at line 807 of file nnet-nnet.cc.

References Nnet::Check(), Nnet::component_names_, Nnet::components_, Nnet::Destroy(), rnnlm::i, Nnet::node_names_, and Nnet::nodes_.

807  {
808  if (this == &nnet)
809  return *this;
810  Destroy();
811  component_names_ = nnet.component_names_;
812  components_.resize(nnet.components_.size());
813  node_names_ = nnet.node_names_;
814  nodes_ = nnet.nodes_;
815  for (size_t i = 0; i < components_.size(); i++)
816  components_[i] = nnet.components_[i]->Copy();
817  Check();
818  return *this;
819 }
std::vector< std::string > component_names_
Definition: nnet-nnet.h:326
std::vector< Component * > components_
Definition: nnet-nnet.h:331
std::vector< std::string > node_names_
Definition: nnet-nnet.h:337
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340
void Check(bool warn_for_orphans=true) const
Checks the neural network for validity (dimension matches and various other requirements).
Definition: nnet-nnet.cc:694

◆ OutputDim()

int32 OutputDim ( const std::string &  output_name) const

Definition at line 677 of file nnet-nnet.cc.

References NetworkNode::Dim(), and rnnlm::n.

Referenced by BatchedXvectorComputer::BatchedXvectorComputer(), AmNnetSimple::Info(), DecodableNnetSimpleLoopedInfo::Init(), main(), NnetBatchComputer::NnetBatchComputer(), kaldi::nnet3::NnetInfo(), AmNnetSimple::NumPdfs(), AmNnetSimple::SetNnet(), AmNnetSimple::SetPriors(), kaldi::nnet3::TestNnetDecodable(), kaldi::nnet3::UnitTestNnetInputDerivatives(), and kaldi::nnet3::UnitTestNnetModelDerivatives().

677  {
678  int32 n = GetNodeIndex(input_name);
679  if (n == -1 || !IsOutputNode(n)) return -1;
680  const NetworkNode &node = nodes_[n];
681  return node.Dim(*this);
682 }
kaldi::int32 int32
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340
bool IsOutputNode(int32 node) const
Returns true if this is an output node, meaning that it is of type kDescriptor and is not directly fo...
Definition: nnet-nnet.cc:112
struct rnnlm::@11::@12 n
int32 GetNodeIndex(const std::string &node_name) const
returns index associated with this node name, or -1 if no such index.
Definition: nnet-nnet.cc:466

◆ ProcessComponentConfigLine()

void ProcessComponentConfigLine ( int32  initial_num_components,
ConfigLine *  config 
)
private

Definition at line 247 of file nnet-nnet.cc.

References ConfigLine::GetValue(), ConfigLine::HasUnusedValues(), Component::InitFromConfig(), kaldi::IsToken(), KALDI_ERR, Component::NewComponentOfType(), ConfigLine::UnusedValues(), and ConfigLine::WholeLine().

249  {
250  std::string name, type;
251  if (!config->GetValue("name", &name))
252  KALDI_ERR << "Expected field name=<component-name> in config line: "
253  << config->WholeLine();
254  if (!IsToken(name)) // e.g. contains a space.
255  KALDI_ERR << "Component name '" << name << "' is not allowed, in line: "
256  << config->WholeLine();
257  if (!config->GetValue("type", &type))
258  KALDI_ERR << "Expected field type=<component-type> in config line: "
259  << config->WholeLine();
260  Component *new_component = Component::NewComponentOfType(type);
261  if (new_component == NULL)
262  KALDI_ERR << "Unknown component-type '" << type
263  << "' in config file. Check your code version and config.";
264  // the next call will call KALDI_ERR or KALDI_ASSERT and die if something
265  // went wrong.
266  new_component->InitFromConfig(config);
267  int32 index = GetComponentIndex(name);
268  if (index != -1) { // Replacing existing component.
269  if (index >= initial_num_components) {
270  // that index was something we added from this config.
271  KALDI_ERR << "You are adding two components with the same name: '"
272  << name << "'";
273  }
274  delete components_[index];
275  components_[index] = new_component;
276  } else {
277  components_.push_back(new_component);
278  component_names_.push_back(name);
279  }
280  if (config->HasUnusedValues())
281  KALDI_ERR << "Unused values '" << config->UnusedValues()
282  << "' in config line: " << config->WholeLine();
283 }
kaldi::int32 int32
std::vector< std::string > component_names_
Definition: nnet-nnet.h:326
std::vector< Component * > components_
Definition: nnet-nnet.h:331
bool IsToken(const std::string &token)
Returns true if "token" is nonempty, and all characters are printable and whitespace-free.
Definition: text-utils.cc:105
int32 GetComponentIndex(const std::string &node_name) const
returns index associated with this component name, or -1 if no such index.
Definition: nnet-nnet.cc:474
#define KALDI_ERR
Definition: kaldi-error.h:147
static Component * NewComponentOfType(const std::string &type)
Returns a new Component of the given type e.g.

◆ ProcessComponentNodeConfigLine()

void ProcessComponentNodeConfigLine ( int32  pass,
ConfigLine *  config 
)
private

Definition at line 286 of file nnet-nnet.cc.

References NetworkNode::component_index, NetworkNode::descriptor, kaldi::nnet3::DescriptorTokenize(), ConfigLine::GetValue(), ConfigLine::HasUnusedValues(), KALDI_ASSERT, KALDI_ERR, kaldi::nnet3::kComponent, kaldi::nnet3::kDescriptor, NetworkNode::NetworkNode(), NetworkNode::node_index, Descriptor::Parse(), ConfigLine::UnusedValues(), and ConfigLine::WholeLine().

288  {
289 
290  std::string name;
291  if (!config->GetValue("name", &name))
292  KALDI_ERR << "Expected field name=<component-name> in config line: "
293  << config->WholeLine();
294 
295  std::string input_name = name + std::string("_input");
296  int32 input_node_index = GetNodeIndex(input_name),
297  node_index = GetNodeIndex(name);
298 
299  if (pass == 0) {
300  KALDI_ASSERT(input_node_index == -1 && node_index == -1);
301  // just set up the node types and names for now, we'll properly set them up
302  // on pass 1.
303  nodes_.push_back(NetworkNode(kDescriptor));
304  nodes_.push_back(NetworkNode(kComponent));
305  node_names_.push_back(input_name);
306  node_names_.push_back(name);
307  return;
308  } else {
309  KALDI_ASSERT(input_node_index != -1 && node_index == input_node_index + 1);
310  std::string component_name, input_descriptor;
311  if (!config->GetValue("component", &component_name))
312  KALDI_ERR << "Expected component=<component-name>, in config line: "
313  << config->WholeLine();
314  int32 component_index = GetComponentIndex(component_name);
315  if (component_index == -1)
316  KALDI_ERR << "No component named '" << component_name
317  << "', in config line: " << config->WholeLine();
318  nodes_[node_index].u.component_index = component_index;
319 
320  if (!config->GetValue("input", &input_descriptor))
321  KALDI_ERR << "Expected input=<input-descriptor>, in config line: "
322  << config->WholeLine();
323  std::vector<std::string> tokens;
324  if (!DescriptorTokenize(input_descriptor, &tokens))
325  KALDI_ERR << "Error tokenizing descriptor in config line "
326  << config->WholeLine();
327  std::vector<std::string> node_names_temp;
328  GetSomeNodeNames(&node_names_temp);
329  tokens.push_back("end of input");
330  const std::string *next_token = &(tokens[0]);
331  if (!nodes_[input_node_index].descriptor.Parse(node_names_temp,
332  &next_token))
333  KALDI_ERR << "Error parsing Descriptor in config line: "
334  << config->WholeLine();
335  if (config->HasUnusedValues())
336  KALDI_ERR << "Unused values '" << config->UnusedValues()
337  << " in config line: " << config->WholeLine();
338  }
339 }
void GetSomeNodeNames(std::vector< std::string > *modified_node_names) const
Definition: nnet-nnet.cc:563
bool DescriptorTokenize(const std::string &input, std::vector< std::string > *tokens)
This function tokenizes input when parsing Descriptor configuration values.
Definition: nnet-parse.cc:30
kaldi::int32 int32
std::vector< std::string > node_names_
Definition: nnet-nnet.h:337
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340
int32 GetComponentIndex(const std::string &node_name) const
returns index associated with this component name, or -1 if no such index.
Definition: nnet-nnet.cc:474
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
int32 GetNodeIndex(const std::string &node_name) const
returns index associated with this node name, or -1 if no such index.
Definition: nnet-nnet.cc:466

◆ ProcessDimRangeNodeConfigLine()

void ProcessDimRangeNodeConfigLine ( int32  pass,
ConfigLine *  config 
)
private

Definition at line 420 of file nnet-nnet.cc.

References NetworkNode::dim, NetworkNode::dim_offset, ConfigLine::GetValue(), ConfigLine::HasUnusedValues(), KALDI_ASSERT, KALDI_ERR, kaldi::nnet3::kComponent, kaldi::nnet3::kDimRange, kaldi::nnet3::kInput, NetworkNode::NetworkNode(), NetworkNode::node_index, NetworkNode::node_type, NetworkNode::u, ConfigLine::UnusedValues(), and ConfigLine::WholeLine().

422  {
423  std::string name;
424  if (!config->GetValue("name", &name))
425  KALDI_ERR << "Expected field name=<input-name> in config line: "
426  << config->WholeLine();
427  int32 node_index = GetNodeIndex(name);
428  if (pass == 0) {
429  KALDI_ASSERT(node_index == -1);
430  nodes_.push_back(NetworkNode(kDimRange));
431  node_names_.push_back(name);
432  } else {
433  KALDI_ASSERT(node_index != -1);
434  std::string input_node_name;
435  if (!config->GetValue("input-node", &input_node_name))
436  KALDI_ERR << "Expected input-node=<input-node-name>, in config line: "
437  << config->WholeLine();
438  int32 dim, dim_offset;
439  if (!config->GetValue("dim", &dim))
440  KALDI_ERR << "Expected dim=<feature-dim>, in config line: "
441  << config->WholeLine();
442  if (!config->GetValue("dim-offset", &dim_offset))
443  KALDI_ERR << "Expected dim-offset=<dimension-offset>, in config line: "
444  << config->WholeLine();
445 
446  int32 input_node_index = GetNodeIndex(input_node_name);
447  if (input_node_index == -1 ||
448  !(nodes_[input_node_index].node_type == kComponent ||
449  nodes_[input_node_index].node_type == kInput))
450  KALDI_ERR << "invalid input-node " << input_node_name
451  << ": " << config->WholeLine();
452 
453  if (config->HasUnusedValues())
454  KALDI_ERR << "Unused values '" << config->UnusedValues()
455  << " in config line: " << config->WholeLine();
456 
457  NetworkNode &node = nodes_[node_index];
458  KALDI_ASSERT(node.node_type == kDimRange);
459  node.u.node_index = input_node_index;
460  node.dim = dim;
461  node.dim_offset = dim_offset;
462  }
463 }
kaldi::int32 int32
std::vector< std::string > node_names_
Definition: nnet-nnet.h:337
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
int32 GetNodeIndex(const std::string &node_name) const
returns index associated with this node name, or -1 if no such index.
Definition: nnet-nnet.cc:466

◆ ProcessInputNodeConfigLine()

void ProcessInputNodeConfigLine ( ConfigLine *  config)
private

Definition at line 342 of file nnet-nnet.cc.

References NetworkNode::dim, ConfigLine::GetValue(), ConfigLine::HasUnusedValues(), KALDI_ASSERT, KALDI_ERR, kaldi::nnet3::kInput, NetworkNode::NetworkNode(), NetworkNode::node_index, ConfigLine::UnusedValues(), and ConfigLine::WholeLine().

343  {
344  std::string name;
345  if (!config->GetValue("name", &name))
346  KALDI_ERR << "Expected field name=<input-name> in config line: "
347  << config->WholeLine();
348  int32 dim;
349  if (!config->GetValue("dim", &dim))
350  KALDI_ERR << "Expected field dim=<input-dim> in config line: "
351  << config->WholeLine();
352 
353  if (config->HasUnusedValues())
354  KALDI_ERR << "Unused values '" << config->UnusedValues()
355  << " in config line: " << config->WholeLine();
356 
357  KALDI_ASSERT(GetNodeIndex(name) == -1);
358  if (dim <= 0)
359  KALDI_ERR << "Invalid dimension in config line: " << config->WholeLine();
360 
361  int32 node_index = nodes_.size();
362  nodes_.push_back(NetworkNode(kInput));
363  nodes_[node_index].dim = dim;
364  node_names_.push_back(name);
365 }
kaldi::int32 int32
std::vector< std::string > node_names_
Definition: nnet-nnet.h:337
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
int32 GetNodeIndex(const std::string &node_name) const
returns index associated with this node name, or -1 if no such index.
Definition: nnet-nnet.cc:466

◆ ProcessOutputNodeConfigLine()

void ProcessOutputNodeConfigLine ( int32  pass,
ConfigLine *  config 
)
private

Definition at line 368 of file nnet-nnet.cc.

References NetworkNode::descriptor, kaldi::nnet3::DescriptorTokenize(), ConfigLine::GetValue(), ConfigLine::HasUnusedValues(), KALDI_ASSERT, KALDI_ERR, kaldi::nnet3::kDescriptor, kaldi::nnet3::kLinear, kaldi::nnet3::kQuadratic, NetworkNode::NetworkNode(), NetworkNode::node_index, NetworkNode::objective_type, Descriptor::Parse(), ConfigLine::UnusedValues(), and ConfigLine::WholeLine().

370  {
371  std::string name;
372  if (!config->GetValue("name", &name))
373  KALDI_ERR << "Expected field name=<input-name> in config line: "
374  << config->WholeLine();
375  int32 node_index = GetNodeIndex(name);
376  if (pass == 0) {
377  KALDI_ASSERT(node_index == -1);
378  nodes_.push_back(NetworkNode(kDescriptor));
379  node_names_.push_back(name);
380  } else {
381  KALDI_ASSERT(node_index != -1);
382  std::string input_descriptor;
383  if (!config->GetValue("input", &input_descriptor))
384  KALDI_ERR << "Expected input=<input-descriptor>, in config line: "
385  << config->WholeLine();
386  std::vector<std::string> tokens;
387  if (!DescriptorTokenize(input_descriptor, &tokens))
388  KALDI_ERR << "Error tokenizing descriptor in config line "
389  << config->WholeLine();
390  tokens.push_back("end of input");
391  // if the following fails it will die.
392  std::vector<std::string> node_names_temp;
393  GetSomeNodeNames(&node_names_temp);
394  const std::string *next_token = &(tokens[0]);
395  if (!nodes_[node_index].descriptor.Parse(node_names_temp, &next_token))
396  KALDI_ERR << "Error parsing descriptor (input=...) in config line "
397  << config->WholeLine();
398  std::string objective_type;
399  if (config->GetValue("objective", &objective_type)) {
400  if (objective_type == "linear") {
401  nodes_[node_index].u.objective_type = kLinear;
402  } else if (objective_type == "quadratic") {
403  nodes_[node_index].u.objective_type = kQuadratic;
404  } else {
405  KALDI_ERR << "Invalid objective type: " << objective_type;
406  }
407  } else {
408  // the default objective type is linear. This is what we use
409  // for softmax objectives; the LogSoftmaxLayer is included as the
410  // last layer, in this case.
411  nodes_[node_index].u.objective_type = kLinear;
412  }
413  if (config->HasUnusedValues())
414  KALDI_ERR << "Unused values '" << config->UnusedValues()
415  << " in config line: " << config->WholeLine();
416  }
417 }
void GetSomeNodeNames(std::vector< std::string > *modified_node_names) const
Definition: nnet-nnet.cc:563
bool DescriptorTokenize(const std::string &input, std::vector< std::string > *tokens)
This function tokenizes input when parsing Descriptor configuration values.
Definition: nnet-parse.cc:30
kaldi::int32 int32
std::vector< std::string > node_names_
Definition: nnet-nnet.h:337
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
int32 GetNodeIndex(const std::string &node_name) const
returns index associated with this node name, or -1 if no such index.
Definition: nnet-nnet.cc:466

◆ Read()

void Read ( std::istream &  istream,
bool  binary 
)

Definition at line 586 of file nnet-nnet.cc.

References kaldi::nnet3::ExpectToken(), AmNnetSimple::GetNnet(), KALDI_ASSERT, KALDI_ERR, kaldi::PeekToken(), AmNnetSimple::Read(), TransitionModel::Read(), kaldi::ReadBasicType(), Component::ReadNew(), kaldi::ReadToken(), and Nnet::Swap().

Referenced by AmNnetSimple::Read(), and kaldi::nnet3::UnitTestNnetIo().

586  {
587  Destroy();
588  int first_char = PeekToken(is, binary);
589  if (first_char == 'T') {
590  // This branch is to allow '.mdl' files (containing a TransitionModel
591  // and then an AmNnetSimple) to be read where .raw files (containing
592  // just an Nnet) would be expected. This is often convenient.
593  TransitionModel temp_trans_model;
594  temp_trans_model.Read(is, binary);
595  AmNnetSimple temp_am_nnet;
596  temp_am_nnet.Read(is, binary);
597  temp_am_nnet.GetNnet().Swap(this);
598  return;
599  }
600 
601  ExpectToken(is, binary, "<Nnet3>");
602  std::ostringstream config_file_out;
603  std::string cur_line;
604  getline(is, cur_line); // Eat up a single newline.
605  if (!(cur_line == "" || cur_line == "\r"))
606  KALDI_ERR << "Expected newline in config file, got " << cur_line;
607  while (getline(is, cur_line)) {
608  // config-file part of file is terminated by an empty line.
609  if (cur_line == "" || cur_line == "\r")
610  break;
611  config_file_out << cur_line << std::endl;
612  }
613  // Now we read the Components; later we try to parse the config_lines.
614  ExpectToken(is, binary, "<NumComponents>");
615  int32 num_components;
616  ReadBasicType(is, binary, &num_components);
617  KALDI_ASSERT(num_components >= 0 && num_components < 100000);
618  components_.resize(num_components, NULL);
619  component_names_.resize(num_components);
620  for (int32 c = 0; c < num_components; c++) {
621  ExpectToken(is, binary, "<ComponentName>");
622  ReadToken(is, binary, &(component_names_[c]));
623  components_[c] = Component::ReadNew(is, binary);
624  }
625  ExpectToken(is, binary, "</Nnet3>");
626  std::istringstream config_file_in(config_file_out.str());
627  this->ReadConfig(config_file_in);
628 }
void ReadConfig(std::istream &config_file)
Definition: nnet-nnet.cc:189
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
kaldi::int32 int32
void ReadToken(std::istream &is, bool binary, std::string *str)
ReadToken gets the next token and puts it in str (exception on failure).
Definition: io-funcs.cc:154
std::vector< std::string > component_names_
Definition: nnet-nnet.h:326
std::vector< Component * > components_
Definition: nnet-nnet.h:331
static void ExpectToken(const std::string &token, const std::string &what_we_are_parsing, const std::string **next_token)
#define KALDI_ERR
Definition: kaldi-error.h:147
static Component * ReadNew(std::istream &is, bool binary)
Read component from stream (works out its type). Dies on error.
int PeekToken(std::istream &is, bool binary)
PeekToken will return the first character of the next token, or -1 if end of file.
Definition: io-funcs.cc:170
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ ReadConfig()

void ReadConfig ( std::istream &  config_file)

Definition at line 189 of file nnet-nnet.cc.

References rnnlm::i, KALDI_ERR, kaldi::nnet3::ParseConfigLines(), and kaldi::ReadConfigLines().

Referenced by main(), kaldi::nnet3::ModifyNnetIvectorPeriod(), SvdApplier::ModifyTopology(), kaldi::nnet3::UnitTestConvertRepeatedToBlockAffine(), kaldi::nnet3::UnitTestConvertRepeatedToBlockAffineComposite(), kaldi::nnet3::UnitTestNnetAnalyze(), kaldi::nnet3::UnitTestNnetCompile(), kaldi::nnet3::UnitTestNnetCompileLooped(), kaldi::nnet3::UnitTestNnetCompileMulti(), kaldi::nnet3::UnitTestNnetCompute(), kaldi::nnet3::UnitTestNnetContext(), kaldi::nnet3::UnitTestNnetInputDerivatives(), kaldi::nnet3::UnitTestNnetIo(), kaldi::nnet3::UnitTestNnetModelDerivatives(), and kaldi::nnet3::UnitTestNnetOptimizeWithOptions().

189  {
190 
191  std::vector<std::string> lines;
192  // Write into "lines" a config file corresponding to whatever
193  // nodes we currently have. Because the numbering of nodes may
194  // change, it's most convenient to convert to the text representation
195  // and combine the existing and new config lines in that representation.
196  const bool include_dim = false;
197  GetConfigLines(include_dim, &lines);
198 
199  // we'll later regenerate what we need from nodes_ and node_name_ from the
200  // string representation.
201  nodes_.clear();
202  node_names_.clear();
203 
204  int32 num_lines_initial = lines.size();
205 
206  ReadConfigLines(config_is, &lines);
207  // now "lines" will have comments removed and empty lines stripped out
208 
209  std::vector<ConfigLine> config_lines(lines.size());
210 
211  ParseConfigLines(lines, &config_lines);
212 
213  // the next line will possibly remove some elements from "config_lines" so no
214  // node or component is doubly defined, always keeping the second repeat.
215  // Things being doubly defined can happen when a previously existing node or
216  // component is redefined in a new config file.
217  RemoveRedundantConfigLines(num_lines_initial, &config_lines);
218 
219  int32 initial_num_components = components_.size();
220  for (int32 pass = 0; pass <= 1; pass++) {
221  for (size_t i = 0; i < config_lines.size(); i++) {
222  const std::string &first_token = config_lines[i].FirstToken();
223  if (first_token == "component") {
224  if (pass == 0)
225  ProcessComponentConfigLine(initial_num_components,
226  &(config_lines[i]));
227  } else if (first_token == "component-node") {
228  ProcessComponentNodeConfigLine(pass, &(config_lines[i]));
229  } else if (first_token == "input-node") {
230  if (pass == 0)
231  ProcessInputNodeConfigLine(&(config_lines[i]));
232  } else if (first_token == "output-node") {
233  ProcessOutputNodeConfigLine(pass, &(config_lines[i]));
234  } else if (first_token == "dim-range-node") {
235  ProcessDimRangeNodeConfigLine(pass, &(config_lines[i]));
236  } else {
237  KALDI_ERR << "Invalid config-file line ('" << first_token
238  << "' not expected): " << config_lines[i].WholeLine();
239  }
240  }
241  }
242  Check();
243 }
static void RemoveRedundantConfigLines(int32 num_lines_initial, std::vector< ConfigLine > *config_lines)
Definition: nnet-nnet.cc:486
kaldi::int32 int32
std::vector< Component * > components_
Definition: nnet-nnet.h:331
std::vector< std::string > node_names_
Definition: nnet-nnet.h:337
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340
void ProcessOutputNodeConfigLine(int32 pass, ConfigLine *config)
Definition: nnet-nnet.cc:368
void ParseConfigLines(const std::vector< std::string > &lines, std::vector< ConfigLine > *config_lines)
Definition: nnet-parse.cc:224
void GetConfigLines(bool include_dim, std::vector< std::string > *config_lines) const
Definition: nnet-nnet.cc:180
void ProcessComponentConfigLine(int32 initial_num_components, ConfigLine *config)
Definition: nnet-nnet.cc:247
#define KALDI_ERR
Definition: kaldi-error.h:147
void ReadConfigLines(std::istream &is, std::vector< std::string > *lines)
This function reads in a config file and *appends* its contents to a vector of lines; it is responsib...
Definition: text-utils.cc:564
void ProcessDimRangeNodeConfigLine(int32 pass, ConfigLine *config)
Definition: nnet-nnet.cc:420
void Check(bool warn_for_orphans=true) const
Checks the neural network for validity (dimension matches and various other requirements).
Definition: nnet-nnet.cc:694
void ProcessInputNodeConfigLine(ConfigLine *config)
Definition: nnet-nnet.cc:342
void ProcessComponentNodeConfigLine(int32 pass, ConfigLine *config)
Definition: nnet-nnet.cc:286

◆ RemoveOrphanComponents()

void RemoveOrphanComponents ( )

Definition at line 844 of file nnet-nnet.cc.

References Nnet::Check(), Nnet::component_names_, Nnet::components_, kaldi::nnet3::FindOrphanComponents(), rnnlm::i, Nnet::IsComponentNode(), KALDI_ASSERT, KALDI_LOG, rnnlm::n, Nnet::nodes_, and Nnet::NumNodes().

Referenced by ModelCollapser::Collapse(), SvdApplier::ModifyTopology(), and kaldi::nnet3::ReadEditConfig().

844  {
845  std::vector<int32> orphan_components;
846  FindOrphanComponents(*this, &orphan_components);
847  KALDI_LOG << "Removing " << orphan_components.size()
848  << " orphan components.";
849  if (orphan_components.empty())
850  return;
851  int32 old_num_components = components_.size(),
852  new_num_components = 0;
853  std::vector<int32> old2new_map(old_num_components, 0);
854  for (size_t i = 0; i < orphan_components.size(); i++)
855  old2new_map[orphan_components[i]] = -1;
856  std::vector<Component*> new_components;
857  std::vector<std::string> new_component_names;
858  for (int32 c = 0; c < old_num_components; c++) {
859  if (old2new_map[c] != -1) {
860  old2new_map[c] = new_num_components++;
861  new_components.push_back(components_[c]);
862  new_component_names.push_back(component_names_[c]);
863  } else {
864  delete components_[c];
865  components_[c] = NULL;
866  }
867  }
868  for (int32 n = 0; n < NumNodes(); n++) {
869  if (IsComponentNode(n)) {
870  int32 old_c = nodes_[n].u.component_index,
871  new_c = old2new_map[old_c];
872  KALDI_ASSERT(new_c >= 0);
873  nodes_[n].u.component_index = new_c;
874  }
875  }
876  components_ = new_components;
877  component_names_ = new_component_names;
878  Check();
879 }
int32 NumNodes() const
Definition: nnet-nnet.h:126
void FindOrphanComponents(const Nnet &nnet, std::vector< int32 > *components)
This function finds a list of components that are never used, and outputs the integer component index...
Definition: nnet-utils.cc:591
kaldi::int32 int32
std::vector< std::string > component_names_
Definition: nnet-nnet.h:326
std::vector< Component * > components_
Definition: nnet-nnet.h:331
bool IsComponentNode(int32 node) const
Returns true if this is a component node, meaning that it is of type kComponent.
Definition: nnet-nnet.cc:132
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340
struct rnnlm::@11::@12 n
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void Check(bool warn_for_orphans=true) const
Checks the neural network for validity (dimension matches and various other requirements).
Definition: nnet-nnet.cc:694
#define KALDI_LOG
Definition: kaldi-error.h:153

◆ RemoveOrphanNodes()

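void RemoveOrphanNodes ( bool  remove_orphan_inputs = false)

Removes every node that FindOrphanNodes() reports as not being used to compute any output, then logs how many nodes were removed. Component-input nodes are excluded from the logged count. Orphan input nodes are intended to be kept unless remove_orphan_inputs is true.

Note: the filtering loop at lines 936–938 erases from orphan_nodes while still incrementing i. As a result, an input node that immediately follows an erased input node in orphan_nodes is never checked, and it is removed even when remove_orphan_inputs is false.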

Definition at line 932 of file nnet-nnet.cc.

References kaldi::nnet3::FindOrphanNodes(), rnnlm::i, Nnet::IsComponentInputNode(), Nnet::IsInputNode(), KALDI_LOG, and Nnet::RemoveSomeNodes().

Referenced by ModelCollapser::Collapse(), SvdApplier::ModifyTopology(), and kaldi::nnet3::ReadEditConfig().

932  {
933  std::vector<int32> orphan_nodes;
934  FindOrphanNodes(*this, &orphan_nodes);
935  if (!remove_orphan_inputs)
936  for (int32 i = 0; i < orphan_nodes.size(); i++)
937  if (IsInputNode(orphan_nodes[i]))
938  orphan_nodes.erase(orphan_nodes.begin() + i);
939  // For each component-node, its component-input node (which is kind of a
940  // "hidden" node) would be included in 'orphan_nodes', but for diagnostic
941  // purposes we want to exclude these from 'num_nodes_removed' to avoid
942  // confusing users.
943  int32 num_nodes_removed = 0;
944  for (int32 i = 0; i < orphan_nodes.size(); i++)
945  if (!IsComponentInputNode(orphan_nodes[i]))
946  num_nodes_removed++;
947  RemoveSomeNodes(orphan_nodes);
948  KALDI_LOG << "Removed " << num_nodes_removed << " orphan nodes.";
949 }
void FindOrphanNodes(const Nnet &nnet, std::vector< int32 > *nodes)
This function finds a list of nodes that are never used to compute any output, and outputs the intege...
Definition: nnet-utils.cc:607
bool IsInputNode(int32 node) const
Returns true if this is an input node, meaning that it is of type kInput.
Definition: nnet-nnet.cc:120
kaldi::int32 int32
void RemoveSomeNodes(const std::vector< int32 > &nodes_to_remove)
Definition: nnet-nnet.cc:881
#define KALDI_LOG
Definition: kaldi-error.h:153
bool IsComponentInputNode(int32 node) const
Returns true if this is component-input node, i.e.
Definition: nnet-nnet.cc:172

◆ RemoveRedundantConfigLines()

void RemoveRedundantConfigLines ( int32  num_lines_initial,
std::vector< ConfigLine > *  config_lines 
)
static, private

Definition at line 486 of file nnet-nnet.cc.

References ConfigLine::FirstToken(), ConfigLine::GetValue(), rnnlm::i, kaldi::IsValidName(), KALDI_ASSERT, KALDI_ERR, and ConfigLine::WholeLine().

487  {
488  int32 num_lines = config_lines->size();
489  KALDI_ASSERT(num_lines_initial <= num_lines);
490  // node names and component names live in different namespaces.
491  unordered_map<std::string, int32, StringHasher> node_name_to_most_recent_line;
492  unordered_set<std::string, StringHasher> component_names;
493  typedef unordered_map<std::string, int32, StringHasher>::iterator IterType;
494 
495  std::vector<bool> to_remove(num_lines, false);
496  for (int32 line = 0; line < num_lines; line++) {
497  ConfigLine &config_line = (*config_lines)[line];
498  std::string name;
499  if (!config_line.GetValue("name", &name))
500  KALDI_ERR << "Config line has no field 'name=xxx': "
501  << config_line.WholeLine();
502  if (!IsValidName(name))
503  KALDI_ERR << "Name '" << name << "' is not allowable, in line: "
504  << config_line.WholeLine();
505  if (config_line.FirstToken() == "component") {
506  // a line starting with "component"... components live in their own
507  // namespace. No repeats are allowed because we never wrote them
508  // to the config generated from the nnet.
509  if (!component_names.insert(name).second) {
510  // we could not insert it because it was already there.
511  KALDI_ERR << "Component name " << name
512  << " appears twice in the same config file.";
513  }
514  } else {
515  // the line defines some sort of network node, e.g. component-node.
516  IterType iter = node_name_to_most_recent_line.find(name);
517  if (iter != node_name_to_most_recent_line.end()) {
518  // name is repeated.
519  int32 prev_line = iter->second;
520  if (prev_line >= num_lines_initial) {
521  // user-provided config contained repeat of node with this name.
522  KALDI_ERR << "Node name " << name
523  << " appears twice in the same config file.";
524  }
525  // following assert checks that the config-file generated
526  // from an actual nnet does not contain repeats.. that
527  // would be a bug so check it with assert.
528  KALDI_ASSERT(line >= num_lines_initial);
529  to_remove[prev_line] = true;
530  }
531  node_name_to_most_recent_line[name] = line;
532  }
533  }
534  // Now remove any lines with to_remove[i] = true.
535  std::vector<ConfigLine> config_lines_out;
536  config_lines_out.reserve(num_lines);
537  for (int32 i = 0; i < num_lines; i++) {
538  if (!to_remove[i])
539  config_lines_out.push_back((*config_lines)[i]);
540  }
541  config_lines->swap(config_lines_out);
542 }
kaldi::int32 int32
bool IsValidName(const std::string &name)
Returns true if 'name' would be a valid name for a component or node in a nnet3 Nnet.
Definition: text-utils.cc:553
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ RemoveSomeNodes()

void RemoveSomeNodes ( const std::vector< int32 > &  nodes_to_remove)

Definition at line 881 of file nnet-nnet.cc.

References Nnet::Check(), kaldi::nnet3::DescriptorTokenize(), rnnlm::i, KALDI_ASSERT, KALDI_ERR, kaldi::nnet3::kDescriptor, kaldi::nnet3::kDimRange, rnnlm::n, Nnet::node_names_, and Nnet::nodes_.

Referenced by kaldi::nnet3::ReadEditConfig(), and Nnet::RemoveOrphanNodes().

881  {
882  if (nodes_to_remove.empty())
883  return;
884  int32 old_num_nodes = nodes_.size(),
885  new_num_nodes = 0;
886  std::vector<int32> old2new_map(old_num_nodes, 0);
887  for (size_t i = 0; i < nodes_to_remove.size(); i++)
888  old2new_map[nodes_to_remove[i]] = -1;
889  std::vector<NetworkNode> new_nodes;
890  std::vector<std::string> new_node_names;
891  for (int32 n = 0; n < old_num_nodes; n++) {
892  if (old2new_map[n] != -1) {
893  old2new_map[n] = new_num_nodes++;
894  new_nodes.push_back(nodes_[n]);
895  new_node_names.push_back(node_names_[n]);
896  }
897  }
898  for (int32 n = 0; n < new_num_nodes; n++) {
899  if (new_nodes[n].node_type == kDescriptor) {
900  // we need to renumber the node indexes inside the descriptor. It's
901  // easiest to do this by converting back and forth to text format. This
902  // is inefficient, of course, but these graphs are typically quite small.
903  std::ostringstream os;
904  new_nodes[n].descriptor.WriteConfig(os, node_names_);
905  std::vector<std::string> tokens;
906  DescriptorTokenize(os.str(), &tokens);
907  KALDI_ASSERT(!tokens.empty());
908  tokens.push_back("end of input");
909  const std::string *token = &(tokens[0]);
910  Descriptor new_descriptor;
911  // this should work; if it doesn't, there was a programming error.
912  if (!new_nodes[n].descriptor.Parse(new_node_names, &token)) {
913  KALDI_ERR << "Code error removing orphan nodes.";
914  }
915  } else if (new_nodes[n].node_type == kDimRange) {
916  int32 old_node_index = new_nodes[n].u.node_index,
917  new_node_index = old2new_map[old_node_index];
918  KALDI_ASSERT(new_node_index >= 0 && new_node_index <= new_num_nodes);
919  new_nodes[n].u.node_index = new_node_index;
920  }
921  }
922  nodes_ = new_nodes;
923  node_names_ = new_node_names;
924  bool warn_for_orphans = false;
925  // don't warn about orphans, because at this stage we may have
926  // orphan components that will later be removed by calling
927  // RemoveOrphanComponents().
928  Check(warn_for_orphans);
929 }
bool DescriptorTokenize(const std::string &input, std::vector< std::string > *tokens)
This function tokenizes input when parsing Descriptor configuration values.
Definition: nnet-parse.cc:30
kaldi::int32 int32
std::vector< std::string > node_names_
Definition: nnet-nnet.h:337
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340
struct rnnlm::@11::@12 n
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void Check(bool warn_for_orphans=true) const
Checks the neural network for validity (dimension matches and various other requirements).
Definition: nnet-nnet.cc:694

◆ ResetGenerators()

void ResetGenerators ( )

Definition at line 951 of file nnet-nnet.cc.

References Nnet::GetComponent(), Nnet::NumComponents(), and RandomComponent::ResetGenerator().

951  {
952  // resets random-number generators for all random
953  // components.
954  for (int32 c = 0; c < NumComponents(); c++) {
955  RandomComponent *rc = dynamic_cast<RandomComponent*>(GetComponent(c));
956  if (rc != NULL)
957  rc->ResetGenerator();
958  }
959 }
kaldi::int32 int32
Component * GetComponent(int32 c)
Return component indexed c. Not a copy; not owned by caller.
Definition: nnet-nnet.cc:150
int32 NumComponents() const
Definition: nnet-nnet.h:124

◆ SetComponent()

void SetComponent ( int32  c,
Component *  component 
)

Replace the component indexed by c with a new component.

Frees previous component indexed by c. Takes ownership of the pointer 'component'.

Definition at line 155 of file nnet-nnet.cc.

References KALDI_ASSERT.

Referenced by kaldi::nnet3::ConvertRepeatedToBlockAffine(), and kaldi::nnet3::ReadEditConfig().

155  {
156  KALDI_ASSERT(static_cast<size_t>(c) < components_.size());
157  delete components_[c];
158  components_[c] = component;
159 }
std::vector< Component * > components_
Definition: nnet-nnet.h:331
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ SetNodeName()

void SetNodeName ( int32  node_index,
const std::string &  new_name 
)

This can be used to modify individual node names.

Note, this does not affect the neural net structure at all, it just assigns a new name to an existing node while leaving all connections identical.

Definition at line 53 of file nnet-nnet.cc.

References kaldi::IsValidName(), KALDI_ERR, and NetworkNode::node_index.

Referenced by kaldi::nnet3::ReadEditConfig().

53  {
54  if (!(static_cast<size_t>(node_index) < nodes_.size()))
55  KALDI_ERR << "Invalid node index";
56  if (GetNodeIndex(new_name) != -1)
57  KALDI_ERR << "You cannot rename a node to create a duplicate node name";
58  if (!IsValidName(new_name))
59  KALDI_ERR << "Node name " << new_name << " is not allowed.";
60  node_names_[node_index] = new_name;
61 }
bool IsValidName(const std::string &name)
Returns true if 'name' would be a valid name for a component or node in a nnet3 Nnet.
Definition: text-utils.cc:553
std::vector< std::string > node_names_
Definition: nnet-nnet.h:337
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340
#define KALDI_ERR
Definition: kaldi-error.h:147
int32 GetNodeIndex(const std::string &node_name) const
returns index associated with this node name, or -1 if no such index.
Definition: nnet-nnet.cc:466

◆ Swap()

void Swap ( Nnet *  other)

Definition at line 579 of file nnet-nnet.cc.

References Nnet::component_names_, Nnet::components_, Nnet::node_names_, and Nnet::nodes_.

Referenced by Nnet::Read().

579  {
580  component_names_.swap(other->component_names_);
581  components_.swap(other->components_);
582  node_names_.swap(other->node_names_);
583  nodes_.swap(other->nodes_);
584 }
std::vector< std::string > component_names_
Definition: nnet-nnet.h:326
std::vector< Component * > components_
Definition: nnet-nnet.h:331
std::vector< std::string > node_names_
Definition: nnet-nnet.h:337
std::vector< NetworkNode > nodes_
Definition: nnet-nnet.h:340

◆ Write()

void Write ( std::ostream &  ostream,
bool  binary 
) const

Definition at line 630 of file nnet-nnet.cc.

References rnnlm::i, KALDI_ASSERT, kaldi::WriteBasicType(), and kaldi::WriteToken().

Referenced by main(), kaldi::nnet3::UnitTestNnetIo(), and AmNnetSimple::Write().

630  {
631  WriteToken(os, binary, "<Nnet3>");
632  os << std::endl;
633  std::vector<std::string> config_lines;
634  const bool include_dim = false;
635  GetConfigLines(include_dim, &config_lines);
636  for (size_t i = 0; i < config_lines.size(); i++) {
637  KALDI_ASSERT(!config_lines[i].empty());
638  os << config_lines[i] << std::endl;
639  }
640  // A blank line terminates the config-like section of the file.
641  os << std::endl;
642  // Now write the Components
643  int32 num_components = components_.size();
644  WriteToken(os, binary, "<NumComponents>");
645  WriteBasicType(os, binary, num_components);
646  if (!binary)
647  os << std::endl;
648  for (int32 c = 0; c < num_components; c++) {
649  WriteToken(os, binary, "<ComponentName>");
650  WriteToken(os, binary, component_names_[c]);
651  components_[c]->Write(os, binary);
652  if (!binary)
653  os << std::endl;
654  }
655  WriteToken(os, binary, "</Nnet3>");
656 }
kaldi::int32 int32
std::vector< std::string > component_names_
Definition: nnet-nnet.h:326
std::vector< Component * > components_
Definition: nnet-nnet.h:331
void GetConfigLines(bool include_dim, std::vector< std::string > *config_lines) const
Definition: nnet-nnet.cc:180
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34

Member Data Documentation

◆ component_names_

std::vector<std::string> component_names_
private

Definition at line 326 of file nnet-nnet.h.

Referenced by Nnet::Info(), Nnet::operator=(), Nnet::RemoveOrphanComponents(), and Nnet::Swap().

◆ components_

std::vector<Component*> components_
private

◆ node_names_

std::vector<std::string> node_names_
private

Definition at line 337 of file nnet-nnet.h.

Referenced by Nnet::operator=(), Nnet::RemoveSomeNodes(), and Nnet::Swap().

◆ nodes_

std::vector<NetworkNode> nodes_
private

The documentation for this class was generated from the following files: