Public Member Functions | |
ComputationExpander (const Nnet &nnet, const MiscComputationInfo &misc_info, const NnetComputation &computation, bool need_debug_info, int32 num_n_values, NnetComputation *expanded_computation) | |
void | Expand () |
Private Member Functions | |
void | InitStrideInfo () |
void | ComputeMatrixInfo () |
void | ComputeDebugInfo () |
void | ComputeSubmatrixInfo () |
void | ExpandRowsCommand (const NnetComputation::Command &c_in, NnetComputation::Command *c_out) |
void | ExpandRowsMultiCommand (const NnetComputation::Command &c_in, NnetComputation::Command *c_out) |
void | ExpandRowRangesCommand (const NnetComputation::Command &c_in, NnetComputation::Command *c_out) |
void | ComputePrecomputedIndexes () |
void | ComputeCommands () |
void | EnsureDebugInfoExists (int32 submatrix_index) |
bool | GetNewSubmatLocationInfo (int32 submat_index, int32 old_row_index, int32 *new_row_index, int32 *n_stride) const |
int32 | GetNewMatrixLocationInfo (int32 old_matrix_index, int32 old_row_index) const |
This function is used in mapping row-indexes into matrices, from the old to the new computation.
void | ExpandIndexes (const std::vector< Index > &indexes, std::vector< Index > *indexes_expanded) const |
Private Attributes | |
std::vector< int32 > | n_stride_ |
const Nnet & | nnet_ |
const MiscComputationInfo & | misc_info_ |
const NnetComputation & | computation_ |
bool | need_debug_info_ |
int32 | num_n_values_ |
NnetComputation * | expanded_computation_ |
Definition at line 3138 of file nnet-optimize-utils.cc.
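ComputationExpander (const Nnet &nnet, const MiscComputationInfo &misc_info, const NnetComputation &computation, bool need_debug_info, int32 num_n_values, NnetComputation *expanded_computation) [inline]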
Definition at line 3140 of file nnet-optimize-utils.cc.
References KALDI_ASSERT.
void ComputeCommands () [private]
Definition at line 3501 of file nnet-optimize-utils.cc.
References NnetComputation::Command::command_type, NnetComputation::commands, ComputationExpander::computation_, kaldi::nnet3::kAcceptInput, kaldi::nnet3::kAddRowRanges, kaldi::nnet3::kAddRows, kaldi::nnet3::kAddRowsMulti, kaldi::nnet3::kAddToRowsMulti, KALDI_ERR, kaldi::nnet3::kAllocMatrix, kaldi::nnet3::kBackprop, kaldi::nnet3::kBackpropNoModelUpdate, kaldi::nnet3::kCompressMatrix, kaldi::nnet3::kCopyRows, kaldi::nnet3::kCopyRowsMulti, kaldi::nnet3::kCopyToRowsMulti, kaldi::nnet3::kDeallocMatrix, kaldi::nnet3::kDecompressMatrix, kaldi::nnet3::kGotoLabel, kaldi::nnet3::kMatrixAdd, kaldi::nnet3::kMatrixCopy, kaldi::nnet3::kNoOperation, kaldi::nnet3::kNoOperationLabel, kaldi::nnet3::kNoOperationMarker, kaldi::nnet3::kNoOperationPermanent, kaldi::nnet3::kPropagate, kaldi::nnet3::kProvideOutput, kaldi::nnet3::kSetConst, and kaldi::nnet3::kSwapMatrix.
void ComputeDebugInfo () [private]
Definition at line 3603 of file nnet-optimize-utils.cc.
References NnetComputation::MatrixDebugInfo::cindexes, ComputationExpander::computation_, NnetComputation::MatrixDebugInfo::is_deriv, KALDI_ASSERT, NnetComputation::matrices, NnetComputation::matrix_debug_info, and rnnlm::n.
void ComputeMatrixInfo () [private]
Definition at line 3588 of file nnet-optimize-utils.cc.
References ComputationExpander::computation_, and NnetComputation::matrices.
void ComputePrecomputedIndexes () [private]
Definition at line 3676 of file nnet-optimize-utils.cc.
References NnetComputation::Command::arg1, NnetComputation::Command::arg2, NnetComputation::Command::command_type, NnetComputation::commands, NnetComputation::component_precomputed_indexes, ComputationExpander::computation_, NnetComputation::PrecomputedIndexesInfo::data, Nnet::GetComponent(), NnetComputation::PrecomputedIndexesInfo::input_indexes, KALDI_ASSERT, kaldi::nnet3::kBackprop, kaldi::nnet3::kBackpropNoModelUpdate, kaldi::nnet3::kPropagate, ComputationExpander::nnet_, NnetComputation::PrecomputedIndexesInfo::output_indexes, and Component::PrecomputeIndexes().
void ComputeSubmatrixInfo () [private]
Definition at line 3637 of file nnet-optimize-utils.cc.
References NnetComputation::MatrixDebugInfo::cindexes, NnetComputation::SubMatrixInfo::col_offset, ComputationExpander::computation_, NnetComputation::GetSubmatrixStrings(), KALDI_ERR, NnetComputation::matrix_debug_info, NnetComputation::SubMatrixInfo::matrix_index, ComputationExpander::nnet_, NnetComputation::SubMatrixInfo::num_cols, NnetComputation::SubMatrixInfo::num_rows, NnetComputation::Print(), NnetComputation::SubMatrixInfo::row_offset, and NnetComputation::submatrices.
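void EnsureDebugInfoExists (int32 submatrix_index) [private]
(No definition location is listed for this member.)
void Expand () [public]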
Definition at line 3573 of file nnet-optimize-utils.cc.
References ComputationExpander::computation_, and NnetComputation::need_model_derivative.
Referenced by kaldi::nnet3::ExpandComputation().
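void ExpandIndexes (const std::vector< Index > &indexes, std::vector< Index > *indexes_expanded) const [private]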
Definition at line 3794 of file nnet-optimize-utils.cc.
References kaldi::nnet3::ConvertNumNValues(), kaldi::nnet3::FindNStride(), and KALDI_ASSERT.
void ExpandRowRangesCommand (const NnetComputation::Command &c_in, NnetComputation::Command *c_out) [private]
Definition at line 3426 of file nnet-optimize-utils.cc.
References NnetComputation::Command::arg1, NnetComputation::Command::arg2, NnetComputation::Command::arg3, ComputationExpander::computation_, NnetComputation::indexes_ranges, KALDI_ASSERT, rnnlm::n, and NnetComputation::submatrices.
void ExpandRowsCommand (const NnetComputation::Command &c_in, NnetComputation::Command *c_out) [private]
Definition at line 3298 of file nnet-optimize-utils.cc.
References NnetComputation::Command::alpha, NnetComputation::Command::arg1, NnetComputation::Command::arg2, NnetComputation::Command::arg3, ComputationExpander::computation_, NnetComputation::indexes, KALDI_ASSERT, rnnlm::n, and NnetComputation::submatrices.
void ExpandRowsMultiCommand (const NnetComputation::Command &c_in, NnetComputation::Command *c_out) [private]
Definition at line 3361 of file nnet-optimize-utils.cc.
References NnetComputation::Command::arg1, NnetComputation::Command::arg2, ComputationExpander::computation_, NnetComputation::indexes_multi, KALDI_ASSERT, rnnlm::n, and NnetComputation::submatrices.
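int32 GetNewMatrixLocationInfo (int32 old_matrix_index, int32 old_row_index) const [private]
This function is used in mapping row-indexes into matrices, from the old to the new computation.
[in] | old_matrix_index | The matrix-index > 0, for which we are mapping row-indexes. The matrix-indexes are the same in the old and new computations. |
[in] | old_row_index | The old row-index into the matrix. |
Returns the row-index in the new (expanded) computation that corresponds to old_row_index in the old computation.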
Definition at line 3761 of file nnet-optimize-utils.cc.
References ComputationExpander::computation_, KALDI_ASSERT, and NnetComputation::matrix_debug_info.
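bool GetNewSubmatLocationInfo (int32 submat_index, int32 old_row_index, int32 *new_row_index, int32 *n_stride) const [private]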
Definition at line 3743 of file nnet-optimize-utils.cc.
References NnetComputation::MatrixDebugInfo::cindexes, ComputationExpander::computation_, NnetComputation::matrix_debug_info, and NnetComputation::submatrices.
void InitStrideInfo () [private]
Definition at line 3548 of file nnet-optimize-utils.cc.
References NnetComputation::MatrixDebugInfo::cindexes, ComputationExpander::computation_, kaldi::nnet3::FindNStride(), KALDI_ASSERT, KALDI_ERR, NnetComputation::matrices, and NnetComputation::matrix_debug_info.
const NnetComputation& computation_ [private]
Definition at line 3290 of file nnet-optimize-utils.cc.
NnetComputation* expanded_computation_ [private]
Definition at line 3293 of file nnet-optimize-utils.cc.
const MiscComputationInfo& misc_info_ [private]
Definition at line 3289 of file nnet-optimize-utils.cc.
std::vector< int32 > n_stride_ [private]
Definition at line 3286 of file nnet-optimize-utils.cc.
bool need_debug_info_ [private]
Definition at line 3291 of file nnet-optimize-utils.cc.
const Nnet& nnet_ [private]
Definition at line 3288 of file nnet-optimize-utils.cc.
int32 num_n_values_ [private]
Definition at line 3292 of file nnet-optimize-utils.cc.