This class is responsible for merging matrices, although you probably want to access it via the function VariableMergingOptimization().
#include <nnet-optimize-utils.h>
Public Member Functions
  VariableMergingOptimizer (const NnetOptimizeOptions &config, const Nnet &nnet, NnetComputation *computation)
  bool MergeVariables ()
Private Member Functions
  std::pair<bool, bool> MayBeMerged (int32 command, int32 s1, int32 s2) const
      This function returns a pair of bools saying whether we can do a (left and/or right) merge respectively, based on the conditions defined in the header.
  void DoMerge (int32 command_index, int32 s_to_keep, int32 m_to_discard)
  void MarkAsDirty (int32 s)
      Marks the variables underlying submatrix 's' as dirty.
  void Initialize ()
Private Attributes
  const NnetOptimizeOptions & config_
  const Nnet & nnet_
  NnetComputation * computation_
  Analyzer analyzer_
  std::vector<std::vector<int32> > matrix_to_submatrix_
  std::vector<bool> variable_dirty_
  bool already_called_merge_variables_
We identify pairs of submatrices which can potentially be merged into a single submatrix.
Suppose there are two different submatrices s1 != s2 that are submatrices of different respective matrices m1 != m2, and somewhere in the computation we have a command C, which is one of:
  (a) the assignment command "s2 = s1", or
  (b) a propagate command with s1 as input and s2 as output, with a component that supports propagate in place, or
  (c) a backprop command with s1 as output-deriv and s2 as input-deriv, with a component that supports backprop in place.
Then the triple (C, s1, s2) is a candidate for merging. We support two types of merging: 'right merging', in which we delete s1 and use s2 instead; and 'left merging', in which we delete s2 and use s1 instead. The two types of merging may seem essentially equivalent, but they are not, because in general s1 and s2 may be submatrices of larger matrices.
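The structural test in cases (a) to (c) can be sketched in C++ as follows. This is a minimal sketch, not Kaldi's implementation: CmdType, Cmd, Candidate and FindCandidates() are hypothetical simplifications, and the in_place_ok field stands in for querying the component's properties. It only checks the pattern (s1 != s2 and m1 != m2); the further merge conditions still have to be checked separately.

```cpp
#include <cassert>
#include <vector>

// Hypothetical, simplified stand-ins for nnet3 command types; the real
// NnetComputation::Command has more fields and more command types.
enum class CmdType { kMatrixCopy, kPropagate, kBackprop, kOther };

struct Cmd {
  CmdType type;
  int in;            // s1: copy source, propagate input, or backprop output-deriv
  int out;           // s2: copy destination, propagate output, or backprop input-deriv
  bool in_place_ok;  // whether the component supports this operation in place
};

struct Candidate {
  int command, s1, s2;
};

// Collects triples (C, s1, s2) matching cases (a)-(c).  submat_to_mat[s] is
// the matrix underlying submatrix s.
std::vector<Candidate> FindCandidates(const std::vector<Cmd> &cmds,
                                      const std::vector<int> &submat_to_mat) {
  std::vector<Candidate> result;
  for (int c = 0; c < static_cast<int>(cmds.size()); ++c) {
    const Cmd &cmd = cmds[c];
    bool pattern_ok =
        cmd.type == CmdType::kMatrixCopy ||
        ((cmd.type == CmdType::kPropagate || cmd.type == CmdType::kBackprop) &&
         cmd.in_place_ok);
    if (!pattern_ok) continue;
    int s1 = cmd.in, s2 = cmd.out;
    // Require different submatrices living in different matrices.
    if (s1 != s2 && submat_to_mat[s1] != submat_to_mat[s2])
      result.push_back({c, s1, s2});
  }
  return result;
}
```

A copy between two submatrices of the same matrix, or a propagate whose component cannot run in place, is not a candidate under this test.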
Note the following definitions:
The conditions that must be satisfied for merges are as follows:
If the command C is case (a), i.e. an assignment operation, then the following conditions must apply:
The procedure is the same for left-merge and right-merge, so we describe both at once. Define s_to_keep and m_to_keep as s1 and m1 if we are left-merging, or as s2 and m2 if we are right-merging; s_to_discard and m_to_discard are defined the opposite way.
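A small sketch of the keep/discard naming, assuming a hypothetical MergePlan struct and PlanMerge() helper (neither is part of Kaldi). The DoMerge() member documented on this page receives only command_index, s_to_keep and m_to_discard, so a choice like this one is made before it is called.

```cpp
#include <cassert>

// Hypothetical record of which submatrix/matrix survives a merge.
struct MergePlan {
  int s_to_keep, m_to_keep, s_to_discard, m_to_discard;
};

// s1 lives in matrix m1 and s2 lives in matrix m2.
MergePlan PlanMerge(bool is_left_merge, int s1, int m1, int s2, int m2) {
  assert(s1 != s2 && m1 != m2);
  if (is_left_merge)  // left merge: delete s2, use s1 instead
    return {s1, m1, s2, m2};
  return {s2, m2, s1, m1};  // right merge: delete s1, use s2 instead
}
```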
The procedure to merge in general is as follows:
At the end, when we call RemoveOrphanMatrices(), the renumbering code will automatically detect that there are duplicate submatrices and merge them, as well as removing the now-unused matrix indexes. After merging, we mark the variables (i.e. row-ranges) underlying s1 and s2 as "dirty" so they can no longer be merged during the lifetime of this object. This keeps the logic simple: instead of reasoning about chains of merges, we apply this optimization repeatedly until it makes no change (see nnet-optimize.cc:VariableMergingOptimization()).
Definition at line 133 of file nnet-optimize-utils.h.
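The repeat-until-no-change strategy can be illustrated with a toy model. In this hypothetical sketch each matrix is a single variable and a candidate is a (keep, discard) pair of matrix indexes; OnePass() plays the role of one optimizer object's lifetime, with dirty flags that start clean on every pass, and MergeUntilConvergence() plays the role of the outer loop. Real variables are row-ranges and real merges operate on submatrices, so this only illustrates why dirty-marking plus repetition reaches a fixed point.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// alias[m] is the matrix that m was merged into (alias[m] == m if none).
int Resolve(const std::vector<int> &alias, int m) {
  while (alias[m] != m) m = alias[m];
  return m;
}

// One pass: fresh dirty flags, and each variable takes part in at most one
// merge.  Returns true if anything was merged.
bool OnePass(const std::vector<std::pair<int, int>> &candidates,
             std::vector<int> *alias) {
  std::vector<bool> dirty(alias->size(), false);
  bool changed = false;
  for (const auto &cand : candidates) {
    int keep = Resolve(*alias, cand.first);
    int discard = Resolve(*alias, cand.second);
    if (keep == discard || dirty[keep] || dirty[discard]) continue;
    (*alias)[discard] = keep;
    dirty[keep] = true;
    dirty[discard] = true;
    changed = true;
  }
  return changed;
}

// Repeats passes until one makes no change; returns the number of passes
// that did change something.
int MergeUntilConvergence(const std::vector<std::pair<int, int>> &candidates,
                          std::vector<int> *alias) {
  int passes = 0;
  while (OnePass(candidates, alias)) ++passes;
  return passes;
}
```

For example, with four matrices and candidates (0,1), (1,2), (2,3), the first pass merges 1 into 0 and 3 into 2 (the (1,2) candidate is blocked because matrix 0 is already dirty), the second pass merges 2 into 0, and a third pass changes nothing.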
VariableMergingOptimizer (const NnetOptimizeOptions &config, const Nnet &nnet, NnetComputation *computation)
Definition at line 711 of file nnet-optimize-utils.cc.
References VariableMergingOptimizer::analyzer_, VariableMergingOptimizer::computation_, kaldi::nnet3::ComputeMatrixToSubmatrix(), Analyzer::Init(), VariableMergingOptimizer::matrix_to_submatrix_, ComputationVariables::NumVariables(), VariableMergingOptimizer::variable_dirty_, and Analyzer::variables.
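Judging from its References list (ComputeMatrixToSubmatrix(), Analyzer::Init(), ComputationVariables::NumVariables()), the constructor builds an index from matrices to their submatrices and sizes the dirty flags to the number of variables. The sketch below mirrors that with hypothetical names (OptimizerState, InitState()); it also assumes the nnet3 convention that index 0 denotes the empty matrix/submatrix, and skips it.

```cpp
#include <cassert>
#include <vector>

// Hypothetical container for the two pieces of state set up here.
struct OptimizerState {
  std::vector<std::vector<int>> matrix_to_submatrix;  // matrix -> submatrices
  std::vector<bool> variable_dirty;                   // one flag per variable
};

// submat_to_mat[s] is the matrix underlying submatrix s.
OptimizerState InitState(const std::vector<int> &submat_to_mat,
                         int num_matrices, int num_variables) {
  OptimizerState state;
  state.matrix_to_submatrix.resize(num_matrices);
  for (int s = 1; s < static_cast<int>(submat_to_mat.size()); ++s) {
    int m = submat_to_mat[s];
    assert(m > 0 && m < num_matrices);
    state.matrix_to_submatrix[m].push_back(s);
  }
  // No variable has taken part in a merge yet.
  state.variable_dirty.assign(num_variables, false);
  return state;
}
```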
void DoMerge (int32 command_index, int32 s_to_keep, int32 m_to_discard)   [private]
Definition at line 819 of file nnet-optimize-utils.cc.
References NnetComputation::Command::alpha, VariableMergingOptimizer::analyzer_, NnetComputation::Command::arg1, NnetComputation::Command::arg2, NnetComputation::Command::command_type, NnetComputation::commands, VariableMergingOptimizer::computation_, ComputationAnalysis::FirstNontrivialMatrixAccess(), kaldi::nnet3::GetSubMatrixOfSubMatrix(), kaldi::nnet3::kAcceptInput, KALDI_ASSERT, kaldi::nnet3::kMatrixCopy, kaldi::nnet3::kNoOperation, kaldi::nnet3::kSetConst, kaldi::kStrideEqualNumCols, VariableMergingOptimizer::MarkAsDirty(), NnetComputation::matrices, Analyzer::matrix_accesses, VariableMergingOptimizer::matrix_to_submatrix_, and NnetComputation::submatrices.
Referenced by VariableMergingOptimizer::MergeVariables().
void Initialize ()   [private]
void MarkAsDirty (int32 s)   [private]
Marks the variables underlying submatrix 's' as dirty.
Definition at line 807 of file nnet-optimize-utils.cc.
References VariableMergingOptimizer::analyzer_, ComputationVariables::AppendVariablesForSubmatrix(), KALDI_ASSERT, VariableMergingOptimizer::variable_dirty_, and Analyzer::variables.
Referenced by VariableMergingOptimizer::DoMerge().
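From its References list, MarkAsDirty() appears to collect the variables of submatrix s via ComputationVariables::AppendVariablesForSubmatrix() and set the matching entries of variable_dirty_. The sketch below mirrors that under a simplifying assumption: a precomputed submat_to_vars table and the hypothetical AppendSubmatrixVariables() helper replace the real ComputationVariables lookup.

```cpp
#include <cassert>
#include <vector>

// Hypothetical lookup: submat_to_vars[s] lists the variable (row-range)
// indexes covered by submatrix s; they are appended to *vars.
void AppendSubmatrixVariables(const std::vector<std::vector<int>> &submat_to_vars,
                              int s, std::vector<int> *vars) {
  vars->insert(vars->end(), submat_to_vars[s].begin(), submat_to_vars[s].end());
}

// Sketch of MarkAsDirty(s): every variable underlying s is flagged so that it
// cannot take part in another merge during this optimizer's lifetime.
void MarkAsDirty(const std::vector<std::vector<int>> &submat_to_vars, int s,
                 std::vector<bool> *variable_dirty) {
  std::vector<int> vars;
  AppendSubmatrixVariables(submat_to_vars, s, &vars);
  for (int v : vars) {
    assert(v >= 0 && v < static_cast<int>(variable_dirty->size()));
    (*variable_dirty)[v] = true;
  }
}
```

Because overlapping submatrices share variables, marking one submatrix dirty can also block later merges of another submatrix covering the same rows (here, submatrices 1 and 2 share variable 1).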
std::pair<bool, bool> MayBeMerged (int32 command, int32 s1, int32 s2) const   [private]
This function returns a pair of bools saying whether we can do a (left and/or right) merge respectively, based on the conditions defined in the header.
Note: if any of the variables underlying s1 or s2 is marked as 'dirty' due to a previous merge, this function returns (false,false). The terms left-merge and right-merge are defined in the extended comment above this class. Note: left_merge will always be false if config.allow_left_merge == false, and likewise right_merge will always be false if config.allow_right_merge == false.
Parameters
  command  [in] The command index that assigns s2 := s1, or does a forward or backprop with s1 as the input and s2 as the output.
  s1       [in] A submatrix index s1 > 0.
  s2       [in] A submatrix index s2 > 0.
Definition at line 945 of file nnet-optimize-utils.cc.
References NnetOptimizeOptions::allow_left_merge, NnetOptimizeOptions::allow_right_merge, VariableMergingOptimizer::analyzer_, ComputationVariables::AppendVariablesForSubmatrix(), NnetComputation::commands, VariableMergingOptimizer::computation_, VariableMergingOptimizer::config_, ComputationAnalysis::DataInvalidatedCommand(), ComputationAnalysis::FirstNontrivialAccess(), MatrixAccesses::is_input, MatrixAccesses::is_output, NnetComputation::IsWholeMatrix(), KALDI_ASSERT, kaldi::nnet3::kMatrixCopy, kaldi::kStrideEqualNumCols, ComputationAnalysis::LastAccess(), ComputationAnalysis::LastWriteAccess(), NnetComputation::matrices, Analyzer::matrix_accesses, NnetComputation::submatrices, VariableMergingOptimizer::variable_dirty_, and Analyzer::variables.
Referenced by VariableMergingOptimizer::MergeVariables().
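The two gating rules stated above for MayBeMerged() can be sketched as follows: any dirty variable forces (false,false), and allow_left_merge / allow_right_merge can each veto one direction. The other conditions from the class description are not reproduced on this page, so the hypothetical left_ok and right_ok parameters stand in for them; MergeOptions is a cut-down, hypothetical stand-in for NnetOptimizeOptions.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical cut-down stand-in for NnetOptimizeOptions.
struct MergeOptions {
  bool allow_left_merge = true;
  bool allow_right_merge = true;
};

// Returns (left_merge_ok, right_merge_ok).  vars_s1 / vars_s2 are the
// variables underlying s1 and s2; left_ok / right_ok stand in for the other
// merge conditions, which are not modeled here.
std::pair<bool, bool> MayBeMergedSketch(const MergeOptions &config,
                                        const std::vector<bool> &variable_dirty,
                                        const std::vector<int> &vars_s1,
                                        const std::vector<int> &vars_s2,
                                        bool left_ok, bool right_ok) {
  auto any_dirty = [&variable_dirty](const std::vector<int> &vars) {
    for (int v : vars) {
      assert(v >= 0 && v < static_cast<int>(variable_dirty.size()));
      if (variable_dirty[v]) return true;
    }
    return false;
  };
  // A dirty variable on either side rules out both kinds of merge.
  if (any_dirty(vars_s1) || any_dirty(vars_s2)) return {false, false};
  // Each direction can be disabled independently by configuration.
  return {config.allow_left_merge && left_ok,
          config.allow_right_merge && right_ok};
}
```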
bool MergeVariables ()
Definition at line 723 of file nnet-optimize-utils.cc.
References VariableMergingOptimizer::already_called_merge_variables_, NnetComputation::Command::arg1, NnetComputation::Command::arg2, NnetComputation::Command::arg3, NnetComputation::Command::arg4, NnetComputation::Command::arg5, NnetComputation::Command::arg6, NnetOptimizeOptions::backprop_in_place, NnetComputation::Command::command_type, NnetComputation::commands, VariableMergingOptimizer::computation_, VariableMergingOptimizer::config_, VariableMergingOptimizer::DoMerge(), Nnet::GetComponent(), KALDI_ASSERT, kaldi::nnet3::kBackprop, kaldi::nnet3::kBackpropInPlace, kaldi::nnet3::kBackpropNoModelUpdate, kaldi::nnet3::kMatrixCopy, kaldi::nnet3::kPropagate, kaldi::nnet3::kPropagateInPlace, VariableMergingOptimizer::MayBeMerged(), VariableMergingOptimizer::nnet_, NnetOptimizeOptions::optimize, NnetOptimizeOptions::propagate_in_place, Component::Properties(), NnetOptimizeOptions::remove_assignments, kaldi::nnet3::RemoveNoOps(), and kaldi::nnet3::RenumberComputation().
Referenced by kaldi::nnet3::VariableMergingOptimization().
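MergeVariables() references NnetOptimizeOptions::remove_assignments, propagate_in_place and backprop_in_place together with the command types kMatrixCopy, kPropagate, kBackprop and kBackpropNoModelUpdate. A plausible reading, sketched below as an assumption rather than a transcription of the source, is that each flag enables one of cases (a) to (c). CommandType, InPlaceFlags and IsMergeCandidateType() are hypothetical, component_in_place stands in for checking Component::Properties(), and the optimize flag (also referenced) is not modeled.

```cpp
#include <cassert>

// Hypothetical, simplified command types (not the real nnet3 enum).
enum class CommandType {
  kMatrixCopy, kPropagate, kBackprop, kBackpropNoModelUpdate, kOther
};

// Hypothetical stand-in for the relevant NnetOptimizeOptions flags.
struct InPlaceFlags {
  bool remove_assignments = true;
  bool propagate_in_place = true;
  bool backprop_in_place = true;
};

// Assumed mapping from options to cases (a)-(c).
bool IsMergeCandidateType(const InPlaceFlags &flags, CommandType type,
                          bool component_in_place) {
  switch (type) {
    case CommandType::kMatrixCopy:  // case (a): assignment s2 = s1
      return flags.remove_assignments;
    case CommandType::kPropagate:  // case (b)
      return flags.propagate_in_place && component_in_place;
    case CommandType::kBackprop:  // case (c), with or without model update
    case CommandType::kBackpropNoModelUpdate:
      return flags.backprop_in_place && component_in_place;
    default:
      return false;
  }
}
```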
bool already_called_merge_variables_   [private]
Definition at line 184 of file nnet-optimize-utils.h.
Referenced by VariableMergingOptimizer::MergeVariables().
Analyzer analyzer_   [private]
Definition at line 175 of file nnet-optimize-utils.h.
Referenced by VariableMergingOptimizer::DoMerge(), VariableMergingOptimizer::MarkAsDirty(), VariableMergingOptimizer::MayBeMerged(), and VariableMergingOptimizer::VariableMergingOptimizer().
NnetComputation * computation_   [private]
Definition at line 173 of file nnet-optimize-utils.h.
Referenced by VariableMergingOptimizer::DoMerge(), VariableMergingOptimizer::MayBeMerged(), VariableMergingOptimizer::MergeVariables(), and VariableMergingOptimizer::VariableMergingOptimizer().
const NnetOptimizeOptions & config_   [private]
Definition at line 171 of file nnet-optimize-utils.h.
Referenced by VariableMergingOptimizer::MayBeMerged(), and VariableMergingOptimizer::MergeVariables().
std::vector<std::vector<int32> > matrix_to_submatrix_   [private]
Definition at line 178 of file nnet-optimize-utils.h.
Referenced by VariableMergingOptimizer::DoMerge(), and VariableMergingOptimizer::VariableMergingOptimizer().
const Nnet & nnet_   [private]
Definition at line 172 of file nnet-optimize-utils.h.
Referenced by VariableMergingOptimizer::MergeVariables().
std::vector<bool> variable_dirty_   [private]
Definition at line 182 of file nnet-optimize-utils.h.
Referenced by VariableMergingOptimizer::MarkAsDirty(), VariableMergingOptimizer::MayBeMerged(), and VariableMergingOptimizer::VariableMergingOptimizer().