Classes | |
struct | ModifiedComponentInfo |
Public Member Functions | |
SvdApplier (const std::string component_name_pattern, int32 bottleneck_dim, BaseFloat energy_threshold, BaseFloat shrinkage_threshold, Nnet *nnet) | |
void | ApplySvd () |
Private Member Functions | |
void | DecomposeComponents () |
int32 | GetReducedDimension (const Vector< BaseFloat > &input_vector, int32 lower, int32 upper, BaseFloat min_val) |
bool | DecomposeComponent (const std::string &component_name, const AffineComponent &affine, Component **component_a_out, Component **component_b_out) |
void | ModifyTopology () |
Private Attributes | |
std::vector< int32 > | modification_index_ |
std::vector< ModifiedComponentInfo > | modified_component_info_ |
Nnet * | nnet_ |
int32 | bottleneck_dim_ |
BaseFloat | energy_threshold_ |
BaseFloat | shrinkage_threshold_ |
std::string | component_name_pattern_ |
Definition at line 663 of file nnet-utils.cc.
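SvdApplier (const std::string component_name_pattern, int32 bottleneck_dim, BaseFloat energy_threshold, BaseFloat shrinkage_threshold, Nnet *nnet) |
inline |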
Definition at line 665 of file nnet-utils.cc.
void ApplySvd () |
inline |
Definition at line 674 of file nnet-utils.cc.
References SvdApplier::bottleneck_dim_, SvdApplier::DecomposeComponents(), KALDI_LOG, SvdApplier::modified_component_info_, and SvdApplier::ModifyTopology().
Referenced by kaldi::nnet3::ReadEditConfig().
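bool DecomposeComponent (const std::string &component_name, const AffineComponent &affine, Component **component_a_out, Component **component_b_out) |
inline, private |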
Definition at line 765 of file nnet-utils.cc.
References VectorBase< Real >::AddVec2(), AffineComponent::BiasParams(), SvdApplier::bottleneck_dim_, SvdApplier::energy_threshold_, SvdApplier::GetReducedDimension(), AffineComponent::InputDim(), KALDI_ASSERT, KALDI_LOG, kaldi::kCopyData, AffineComponent::LinearParams(), MatrixBase< Real >::MulColsVec(), AffineComponent::OutputDim(), Matrix< Real >::Resize(), UpdatableComponent::SetUpdatableConfigs(), SvdApplier::shrinkage_threshold_, and kaldi::SortSvd().
Referenced by SvdApplier::DecomposeComponents().
void DecomposeComponents () |
inline, private |
Definition at line 686 of file nnet-utils.cc.
References Nnet::AddComponent(), SvdApplier::bottleneck_dim_, SvdApplier::ModifiedComponentInfo::component_a_index, SvdApplier::ModifiedComponentInfo::component_b_index, SvdApplier::ModifiedComponentInfo::component_index, SvdApplier::ModifiedComponentInfo::component_name, SvdApplier::ModifiedComponentInfo::component_name_a, SvdApplier::ModifiedComponentInfo::component_name_b, SvdApplier::component_name_pattern_, SvdApplier::DecomposeComponent(), Nnet::GetComponent(), Nnet::GetComponentIndex(), Nnet::GetComponentName(), AffineComponent::InputDim(), KALDI_ERR, KALDI_LOG, KALDI_WARN, SvdApplier::modification_index_, SvdApplier::modified_component_info_, rnnlm::n, kaldi::nnet3::NameMatchesPattern(), SvdApplier::nnet_, Nnet::NumComponents(), and AffineComponent::OutputDim().
Referenced by SvdApplier::ApplySvd().
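int32 GetReducedDimension (const Vector< BaseFloat > &input_vector, int32 lower, int32 upper, BaseFloat min_val) |
inline, private |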
Definition at line 743 of file nnet-utils.cc.
References rnnlm::i.
Referenced by SvdApplier::DecomposeComponent().
void ModifyTopology () |
inline, private |
Definition at line 847 of file nnet-utils.cc.
References NetworkNode::component_index, SvdApplier::ModifiedComponentInfo::component_name_a, SvdApplier::ModifiedComponentInfo::component_name_b, NetworkNode::descriptor, NetworkNode::dim, NetworkNode::dim_offset, Nnet::GetComponentName(), Nnet::GetNode(), Nnet::GetNodeNames(), Nnet::IsComponentInputNode(), Nnet::IsComponentNode(), Nnet::IsInputNode(), Nnet::IsOutputNode(), KALDI_ASSERT, KALDI_ERR, kaldi::nnet3::kComponent, kaldi::nnet3::kDescriptor, kaldi::nnet3::kDimRange, kaldi::nnet3::kLinear, SvdApplier::modification_index_, SvdApplier::modified_component_info_, rnnlm::n, SvdApplier::nnet_, NetworkNode::node_index, NetworkNode::node_type, Nnet::NumNodes(), NetworkNode::objective_type, Nnet::ReadConfig(), Nnet::RemoveOrphanComponents(), Nnet::RemoveOrphanNodes(), NetworkNode::u, and Descriptor::WriteConfig().
Referenced by SvdApplier::ApplySvd().
int32 bottleneck_dim_ |
private |
Definition at line 967 of file nnet-utils.cc.
Referenced by SvdApplier::ApplySvd(), SvdApplier::DecomposeComponent(), and SvdApplier::DecomposeComponents().
std::string component_name_pattern_ |
private |
Definition at line 970 of file nnet-utils.cc.
Referenced by SvdApplier::DecomposeComponents().
BaseFloat energy_threshold_ |
private |
Definition at line 968 of file nnet-utils.cc.
Referenced by SvdApplier::DecomposeComponent().
std::vector< int32 > modification_index_ |
private |
Definition at line 945 of file nnet-utils.cc.
Referenced by SvdApplier::DecomposeComponents(), and SvdApplier::ModifyTopology().
std::vector< ModifiedComponentInfo > modified_component_info_ |
private |
Definition at line 963 of file nnet-utils.cc.
Referenced by SvdApplier::ApplySvd(), SvdApplier::DecomposeComponents(), and SvdApplier::ModifyTopology().
Nnet * nnet_ |
private |
Definition at line 966 of file nnet-utils.cc.
Referenced by ModelCollapser::Collapse(), ModelCollapser::CollapseComponentsAffine(), ModelCollapser::CollapseComponentsBatchnorm(), ModelCollapser::CollapseComponentsDropout(), ModelCollapser::CollapseComponentsScale(), SvdApplier::DecomposeComponents(), ModelCollapser::GetDiagonallyPreModifiedComponentIndex(), ModelCollapser::GetScaledComponentIndex(), SvdApplier::ModifyTopology(), ModelCollapser::OptimizeNode(), and ModelCollapser::ReplaceNodeInDescriptor().
BaseFloat shrinkage_threshold_ |
private |
Definition at line 969 of file nnet-utils.cc.
Referenced by SvdApplier::DecomposeComponent().