SvdApplier Class Reference
Collaboration diagram for SvdApplier:

Classes

struct  ModifiedComponentInfo
 

Public Member Functions

 SvdApplier (const std::string component_name_pattern, int32 bottleneck_dim, BaseFloat energy_threshold, BaseFloat shrinkage_threshold, Nnet *nnet)
 
void ApplySvd ()
 

Private Member Functions

void DecomposeComponents ()
 
int32 GetReducedDimension (const Vector< BaseFloat > &input_vector, int32 lower, int32 upper, BaseFloat min_val)
 
bool DecomposeComponent (const std::string &component_name, const AffineComponent &affine, Component **component_a_out, Component **component_b_out)
 
void ModifyTopology ()
 

Private Attributes

std::vector< int32 > modification_index_
 
std::vector< ModifiedComponentInfo > modified_component_info_
 
Nnet * nnet_
 
int32 bottleneck_dim_
 
BaseFloat energy_threshold_
 
BaseFloat shrinkage_threshold_
 
std::string component_name_pattern_
 

Detailed Description

Definition at line 663 of file nnet-utils.cc.

Constructor & Destructor Documentation

◆ SvdApplier()

SvdApplier ( const std::string  component_name_pattern,
int32  bottleneck_dim,
BaseFloat  energy_threshold,
BaseFloat  shrinkage_threshold,
Nnet * nnet 
)
inline

Definition at line 665 of file nnet-utils.cc.

669  : nnet_(nnet),
670  bottleneck_dim_(bottleneck_dim),
671  energy_threshold_(energy_threshold),
672  shrinkage_threshold_(shrinkage_threshold),
673  component_name_pattern_(component_name_pattern) { }
BaseFloat shrinkage_threshold_
Definition: nnet-utils.cc:969
std::string component_name_pattern_
Definition: nnet-utils.cc:970

Member Function Documentation

◆ ApplySvd()

void ApplySvd ( )
inline

Definition at line 674 of file nnet-utils.cc.

References SvdApplier::bottleneck_dim_, SvdApplier::DecomposeComponents(), KALDI_LOG, SvdApplier::modified_component_info_, and SvdApplier::ModifyTopology().

Referenced by kaldi::nnet3::ReadEditConfig().

674  {
675  DecomposeComponents();
676  if (!modified_component_info_.empty())
677  ModifyTopology();
678  KALDI_LOG << "Decomposed " << modified_component_info_.size()
679  << " components with SVD dimension " << bottleneck_dim_;
680  }
std::vector< ModifiedComponentInfo > modified_component_info_
Definition: nnet-utils.cc:963
#define KALDI_LOG
Definition: kaldi-error.h:153

◆ DecomposeComponent()

bool DecomposeComponent ( const std::string &  component_name,
const AffineComponent &  affine,
Component **  component_a_out,
Component **  component_b_out 
)
inlineprivate

Definition at line 765 of file nnet-utils.cc.

References VectorBase< Real >::AddVec2(), AffineComponent::BiasParams(), SvdApplier::bottleneck_dim_, SvdApplier::energy_threshold_, SvdApplier::GetReducedDimension(), AffineComponent::InputDim(), KALDI_ASSERT, KALDI_LOG, kaldi::kCopyData, AffineComponent::LinearParams(), MatrixBase< Real >::MulColsVec(), AffineComponent::OutputDim(), Matrix< Real >::Resize(), UpdatableComponent::SetUpdatableConfigs(), SvdApplier::shrinkage_threshold_, and kaldi::SortSvd().

Referenced by SvdApplier::DecomposeComponents().

768  {
769  int32 input_dim = affine.InputDim(), output_dim = affine.OutputDim();
770  Matrix<BaseFloat> linear_params(affine.LinearParams());
771  Vector<BaseFloat> bias_params(affine.BiasParams());
772  int32 middle_dim = std::min<int32>(input_dim, output_dim);
773 
774  // note: 'linear_params' is of dimension output_dim by input_dim.
775  Vector<BaseFloat> s(middle_dim);
776  Matrix<BaseFloat> A(middle_dim, input_dim),
777  B(output_dim, middle_dim);
778  linear_params.Svd(&s, &B, &A);
779  // make sure the singular values are sorted from greatest to least value.
780  SortSvd(&s, &B, &A);
781  Vector<BaseFloat> s2(s.Dim());
782  s2.AddVec2(1.0, s);
783  BaseFloat s2_sum_orig = s2.Sum();
786  if (energy_threshold_ > 0) {
787  BaseFloat min_singular_sum = energy_threshold_ * s2_sum_orig;
788  bottleneck_dim_ = GetReducedDimension(s2, 0, s2.Dim()-1, min_singular_sum);
789  }
790  SubVector<BaseFloat> this_part(s2, 0, bottleneck_dim_);
791  BaseFloat s2_sum_reduced = this_part.Sum();
792  BaseFloat shrinkage_ratio =
793  static_cast<BaseFloat>(bottleneck_dim_ * (input_dim+output_dim))
794  / static_cast<BaseFloat>(input_dim * output_dim);
795  if (shrinkage_ratio > shrinkage_threshold_) {
796  KALDI_LOG << "Shrinkage ratio " << shrinkage_ratio
797  << " greater than threshold : " << shrinkage_threshold_
798  << " Skipping SVD for this layer.";
799  return false;
800  }
801 
802  s.Resize(bottleneck_dim_, kCopyData);
803  A.Resize(bottleneck_dim_, input_dim, kCopyData);
804  B.Resize(output_dim, bottleneck_dim_, kCopyData);
805  KALDI_LOG << "For component " << component_name
806  << " singular value squared sum changed by "
807  << (s2_sum_orig - s2_sum_reduced)
808  << " (from " << s2_sum_orig << " to " << s2_sum_reduced << ")";
809  KALDI_LOG << "For component " << component_name
810  << " dimension reduced from "
811  << " (" << input_dim << "," << output_dim << ")"
812  << " to [(" << input_dim << "," << bottleneck_dim_
813  << "), (" << bottleneck_dim_ << "," << output_dim <<")]";
814  KALDI_LOG << "shrinkage ratio : " << shrinkage_ratio;
815 
816  // we'll divide the singular values equally between the two
817  // parameter matrices.
818  s.ApplyPow(0.5);
819  A.MulRowsVec(s);
820  B.MulColsVec(s);
821 
822  CuMatrix<BaseFloat> A_cuda(A), B_cuda(B);
823  CuVector<BaseFloat> bias_params_cuda(bias_params);
824 
825  LinearComponent *component_a = new LinearComponent(A_cuda);
826  NaturalGradientAffineComponent *component_b =
827  new NaturalGradientAffineComponent(B_cuda, bias_params_cuda);
828  // set the learning rates, max-change, and so on.
829  component_a->SetUpdatableConfigs(affine);
830  component_b->SetUpdatableConfigs(affine);
831  *component_a_out = component_a;
832  *component_b_out = component_b;
833  return true;
834  }
BaseFloat shrinkage_threshold_
Definition: nnet-utils.cc:969
kaldi::int32 int32
float BaseFloat
Definition: kaldi-types.h:29
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
#define KALDI_LOG
Definition: kaldi-error.h:153
void SortSvd(VectorBase< Real > *s, MatrixBase< Real > *U, MatrixBase< Real > *Vt, bool sort_on_absolute_value)
Function to ensure that SVD is sorted.
int32 GetReducedDimension(const Vector< BaseFloat > &input_vector, int32 lower, int32 upper, BaseFloat min_val)
Definition: nnet-utils.cc:743

◆ DecomposeComponents()

void DecomposeComponents ( )
inlineprivate

Definition at line 686 of file nnet-utils.cc.

References Nnet::AddComponent(), SvdApplier::bottleneck_dim_, SvdApplier::ModifiedComponentInfo::component_a_index, SvdApplier::ModifiedComponentInfo::component_b_index, SvdApplier::ModifiedComponentInfo::component_index, SvdApplier::ModifiedComponentInfo::component_name, SvdApplier::ModifiedComponentInfo::component_name_a, SvdApplier::ModifiedComponentInfo::component_name_b, SvdApplier::component_name_pattern_, SvdApplier::DecomposeComponent(), Nnet::GetComponent(), Nnet::GetComponentIndex(), Nnet::GetComponentName(), AffineComponent::InputDim(), KALDI_ERR, KALDI_LOG, KALDI_WARN, SvdApplier::modification_index_, SvdApplier::modified_component_info_, rnnlm::n, kaldi::nnet3::NameMatchesPattern(), SvdApplier::nnet_, Nnet::NumComponents(), and AffineComponent::OutputDim().

Referenced by SvdApplier::ApplySvd().

686  {
687  int32 num_components = nnet_->NumComponents();
688  modification_index_.resize(num_components, -1);
689  for (int32 c = 0; c < num_components; c++) {
690  Component *component = nnet_->GetComponent(c);
691  std::string component_name = nnet_->GetComponentName(c);
692  if (NameMatchesPattern(component_name.c_str(),
693  component_name_pattern_.c_str())) {
694  AffineComponent *affine = dynamic_cast<AffineComponent*>(component);
695  if (affine == NULL) {
696  KALDI_WARN << "Not decomposing component " << component_name
697  << " as it is not an AffineComponent.";
698  continue;
699  }
700  int32 input_dim = affine->InputDim(),
701  output_dim = affine->OutputDim();
702  if (input_dim <= bottleneck_dim_ || output_dim <= bottleneck_dim_) {
703  KALDI_WARN << "Not decomposing component " << component_name
704  << " with SVD to rank " << bottleneck_dim_
705  << " because its dimension is " << input_dim
706  << " -> " << output_dim;
707  continue;
708  }
709  Component *component_a = NULL, *component_b = NULL;
710  if (DecomposeComponent(component_name, *affine, &component_a, &component_b)) {
711  size_t n = modified_component_info_.size();
712  modification_index_[c] = n;
713  modified_component_info_.resize(n + 1);
714  ModifiedComponentInfo &info = modified_component_info_[n];
715  info.component_index = c;
716  info.component_name = component_name;
717  info.component_name_a = component_name + "_a";
718  info.component_name_b = component_name + "_b";
719  if (nnet_->GetComponentIndex(info.component_name_a) >= 0)
720  KALDI_ERR << "Neural network already has a component named "
721  << info.component_name_a;
722  if (nnet_->GetComponentIndex(info.component_name_b) >= 0)
723  KALDI_ERR << "Neural network already has a component named "
724  << info.component_name_b;
725  info.component_a_index = nnet_->AddComponent(info.component_name_a,
726  component_a);
727  info.component_b_index = nnet_->AddComponent(info.component_name_b,
728  component_b);
729  }
730  }
731  }
732  KALDI_LOG << "Converted " << modified_component_info_.size()
733  << " components to FixedAffineComponent.";
734  }
int32 AddComponent(const std::string &name, Component *component)
Adds a new component with the given name, which should not be the same as any existing component name...
Definition: nnet-nnet.cc:161
kaldi::int32 int32
std::string component_name_pattern_
Definition: nnet-utils.cc:970
bool NameMatchesPattern(const char *name, const char *pattern)
Definition: nnet-parse.cc:235
struct rnnlm::@11::@12 n
int32 GetComponentIndex(const std::string &node_name) const
returns index associated with this component name, or -1 if no such index.
Definition: nnet-nnet.cc:474
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
const std::string & GetComponentName(int32 component_index) const
returns individual component name.
Definition: nnet-nnet.cc:689
Component * GetComponent(int32 c)
Return component indexed c. Not a copy; not owned by caller.
Definition: nnet-nnet.cc:150
int32 NumComponents() const
Definition: nnet-nnet.h:124
std::vector< ModifiedComponentInfo > modified_component_info_
Definition: nnet-utils.cc:963
#define KALDI_LOG
Definition: kaldi-error.h:153
bool DecomposeComponent(const std::string &component_name, const AffineComponent &affine, Component **component_a_out, Component **component_b_out)
Definition: nnet-utils.cc:765
std::vector< int32 > modification_index_
Definition: nnet-utils.cc:945

◆ GetReducedDimension()

int32 GetReducedDimension ( const Vector< BaseFloat > &  input_vector,
int32  lower,
int32  upper,
BaseFloat  min_val 
)
inlineprivate

Definition at line 743 of file nnet-utils.cc.

References rnnlm::i.

Referenced by SvdApplier::DecomposeComponent().

746  {
747  BaseFloat sum = 0;
748  int32 i = 0;
749  for (i = lower; i <= upper; i++) {
750  sum = sum + input_vector(i);
751  if (sum >= min_val) break;
752  }
753  return (i+1);
754  }
kaldi::int32 int32
float BaseFloat
Definition: kaldi-types.h:29

◆ ModifyTopology()

void ModifyTopology ( )
inlineprivate

Definition at line 847 of file nnet-utils.cc.

References NetworkNode::component_index, SvdApplier::ModifiedComponentInfo::component_name_a, SvdApplier::ModifiedComponentInfo::component_name_b, NetworkNode::descriptor, NetworkNode::dim, NetworkNode::dim_offset, Nnet::GetComponentName(), Nnet::GetNode(), Nnet::GetNodeNames(), Nnet::IsComponentInputNode(), Nnet::IsComponentNode(), Nnet::IsInputNode(), Nnet::IsOutputNode(), KALDI_ASSERT, KALDI_ERR, kaldi::nnet3::kComponent, kaldi::nnet3::kDescriptor, kaldi::nnet3::kDimRange, kaldi::nnet3::kLinear, SvdApplier::modification_index_, SvdApplier::modified_component_info_, rnnlm::n, SvdApplier::nnet_, NetworkNode::node_index, NetworkNode::node_type, Nnet::NumNodes(), NetworkNode::objective_type, Nnet::ReadConfig(), Nnet::RemoveOrphanComponents(), Nnet::RemoveOrphanNodes(), NetworkNode::u, and Descriptor::WriteConfig().

Referenced by SvdApplier::ApplySvd().

847  {
848  std::set<int32> nodes_to_modify;
849  std::vector<std::string> node_names_orig = nnet_->GetNodeNames(),
850  node_names_modified = node_names_orig;
851 
852  // The following loop sets up 'nodes_to_modify' and 'node_names_modified'.
853  for (int32 n = 0; n < nnet_->NumNodes(); n++) {
854  if (nnet_->IsComponentNode(n)) {
855  NetworkNode &node = nnet_->GetNode(n);
856  int32 component_index = node.u.component_index,
857  modification_index = modification_index_[component_index];
858  if (modification_index >= 0) {
859  // This is a component-node for one of the components that we're
860  // splitting in two.
861  nodes_to_modify.insert(n);
862  std::string node_name = node_names_orig[n],
863  node_name_b = node_name + "_b";
864  node_names_modified[n] = node_name_b;
865  }
866  }
867  }
868 
869 
870  // config_os is a stream to which we are printing lines that we'll later
871  // read using nnet_->ReadConfig().
872  std::ostringstream config_os;
873  // The following loop writes to 'config_os'. The code is modified from
874  // the private function Nnet::GetAsConfigLine(), and from
875  // Nnet::GetConfigLines().
876  for (int32 n = 0; n < nnet_->NumNodes(); n++) {
877  if (nnet_->IsComponentInputNode(n) || nnet_->IsInputNode(n)) {
878  // component-input descriptor nodes aren't handled separately from their
879  // associated components (we deal with them along with their
880  // component-node); and input-nodes won't be affected so we don't have
881  // to print anything.
882  continue;
883  }
884  const NetworkNode &node = nnet_->GetNode(n);
885  int32 c = node.u.component_index; // 'c' will only be meaningful if the
886  // node is a component-node.
887  std::string node_name = node_names_orig[n];
888  if (node.node_type == kComponent && modification_index_[c] >= 0) {
889  ModifiedComponentInfo &info = modified_component_info_[
890  modification_index_[c]];
891  std::string node_name_a = node_name + "_a",
892  node_name_b = node_name + "_b";
893  // we print two component-nodes, the "a" and "b". The original
894  // one will later be removed when we call RemoveOrphanNodes().
895  config_os << "component-node name=" << node_name_a << " component="
896  << info.component_name_a << " input=";
897  nnet_->GetNode(n-1).descriptor.WriteConfig(config_os, node_names_modified);
898  config_os << "\n";
899  config_os << "component-node name=" << node_name_b << " component="
900  << info.component_name_b << " input=" << node_name_a << "\n";
901  } else {
902  // This code is modified from Nnet::GetAsConfigLine(). The key difference
903  // is that we're using node_names_modified, which will replace all the
904  // nodes we're splitting with their "b" versions.
905  switch (node.node_type) {
906  case kDescriptor:
907  // assert that it's an output-descriptor, not one describing the input to
908  // a component-node.
909  KALDI_ASSERT(nnet_->IsOutputNode(n));
910  config_os << "output-node name=" << node_name << " input=";
911  node.descriptor.WriteConfig(config_os, node_names_modified);
912  config_os << " objective=" << (node.u.objective_type == kLinear ?
913  "linear" : "quadratic");
914  break;
915  case kComponent:
916  config_os << "component-node name=" << node_name << " component="
917  << nnet_->GetComponentName(node.u.component_index)
918  << " input=";
919  nnet_->GetNode(n-1).descriptor.WriteConfig(config_os,
920  node_names_modified);
921  break;
922  case kDimRange:
923  config_os << "dim-range-node name=" << node_name << " input-node="
924  << node_names_modified[node.u.node_index]
925  << " dim-offset=" << node.dim_offset
926  << " dim=" << node.dim;
927  break;
928  default:
929  KALDI_ERR << "Unexpected node type.";
930  }
931  config_os << "\n";
932  }
933  }
934  std::istringstream config_is(config_os.str());
935  nnet_->ReadConfig(config_is);
936  nnet_->RemoveOrphanNodes();
937  nnet_->RemoveOrphanComponents();
938  }
int32 NumNodes() const
Definition: nnet-nnet.h:126
void ReadConfig(std::istream &config_file)
Definition: nnet-nnet.cc:189
bool IsInputNode(int32 node) const
Returns true if this is an input node, meaning that it is of type kInput.
Definition: nnet-nnet.cc:120
kaldi::int32 int32
bool IsComponentNode(int32 node) const
Returns true if this is a component node, meaning that it is of type kComponent.
Definition: nnet-nnet.cc:132
void RemoveOrphanComponents()
Definition: nnet-nnet.cc:844
const NetworkNode & GetNode(int32 node) const
returns const reference to a particular numbered network node.
Definition: nnet-nnet.h:146
bool IsOutputNode(int32 node) const
Returns true if this is an output node, meaning that it is of type kDescriptor and is not directly fo...
Definition: nnet-nnet.cc:112
struct rnnlm::@11::@12 n
void RemoveOrphanNodes(bool remove_orphan_inputs=false)
Definition: nnet-nnet.cc:932
#define KALDI_ERR
Definition: kaldi-error.h:147
const std::string & GetComponentName(int32 component_index) const
returns individual component name.
Definition: nnet-nnet.cc:689
void WriteConfig(std::ostream &os, const std::vector< std::string > &node_names) const
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
std::vector< ModifiedComponentInfo > modified_component_info_
Definition: nnet-utils.cc:963
std::vector< int32 > modification_index_
Definition: nnet-utils.cc:945
const std::vector< std::string > & GetNodeNames() const
returns vector of node names (needed by some parsing code, for instance).
Definition: nnet-nnet.cc:63
bool IsComponentInputNode(int32 node) const
Returns true if this is component-input node, i.e.
Definition: nnet-nnet.cc:172

Member Data Documentation

◆ bottleneck_dim_

int32 bottleneck_dim_
private

◆ component_name_pattern_

std::string component_name_pattern_
private

Definition at line 970 of file nnet-utils.cc.

Referenced by SvdApplier::DecomposeComponents().

◆ energy_threshold_

BaseFloat energy_threshold_
private

Definition at line 968 of file nnet-utils.cc.

Referenced by SvdApplier::DecomposeComponent().

◆ modification_index_

std::vector<int32> modification_index_
private

Definition at line 945 of file nnet-utils.cc.

Referenced by SvdApplier::DecomposeComponents(), and SvdApplier::ModifyTopology().

◆ modified_component_info_

std::vector<ModifiedComponentInfo> modified_component_info_
private

◆ nnet_

Nnet* nnet_
private

◆ shrinkage_threshold_

BaseFloat shrinkage_threshold_
private

Definition at line 969 of file nnet-utils.cc.

Referenced by SvdApplier::DecomposeComponent().


The documentation for this class was generated from the following file: