ComputationExpander Class Reference
Collaboration diagram for ComputationExpander:

Public Member Functions

 ComputationExpander (const Nnet &nnet, const MiscComputationInfo &misc_info, const NnetComputation &computation, bool need_debug_info, int32 num_n_values, NnetComputation *expanded_computation)
 
void Expand ()
 

Private Member Functions

void InitStrideInfo ()
 
void ComputeMatrixInfo ()
 
void ComputeDebugInfo ()
 
void ComputeSubmatrixInfo ()
 
void ExpandRowsCommand (const NnetComputation::Command &c_in, NnetComputation::Command *c_out)
 
void ExpandRowsMultiCommand (const NnetComputation::Command &c_in, NnetComputation::Command *c_out)
 
void ExpandRowRangesCommand (const NnetComputation::Command &c_in, NnetComputation::Command *c_out)
 
void ComputePrecomputedIndexes ()
 
void ComputeCommands ()
 
void EnsureDebugInfoExists (int32 submatrix_index)
 
bool GetNewSubmatLocationInfo (int32 submat_index, int32 old_row_index, int32 *new_row_index, int32 *n_stride) const
 
int32 GetNewMatrixLocationInfo (int32 old_matrix_index, int32 old_row_index) const
 This function is used in mapping row-indexes into matrices, from the old to the new computation. More...
 
void ExpandIndexes (const std::vector< Index > &indexes, std::vector< Index > *indexes_expanded) const
 

Private Attributes

std::vector< int32 > n_stride_
 
const Nnet & nnet_
 
const MiscComputationInfo & misc_info_
 
const NnetComputation & computation_
 
bool need_debug_info_
 
int32 num_n_values_
 
NnetComputation * expanded_computation_
 

Detailed Description

Definition at line 3138 of file nnet-optimize-utils.cc.

Constructor & Destructor Documentation

◆ ComputationExpander()

ComputationExpander ( const Nnet & nnet,
const MiscComputationInfo & misc_info,
const NnetComputation & computation,
bool  need_debug_info,
int32  num_n_values,
NnetComputation * expanded_computation 
)
inline

Definition at line 3140 of file nnet-optimize-utils.cc.

References KALDI_ASSERT.

3145  :
3146  nnet_(nnet), misc_info_(misc_info),
3147  computation_(computation),
3148  need_debug_info_(need_debug_info),
3149  num_n_values_(num_n_values),
3150  expanded_computation_(expanded_computation) {
3151  KALDI_ASSERT(num_n_values > 2);
3152  }
const MiscComputationInfo & misc_info_
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

Member Function Documentation

◆ ComputeCommands()

void ComputeCommands ( )
private

Definition at line 3501 of file nnet-optimize-utils.cc.

References NnetComputation::Command::command_type, NnetComputation::commands, DerivativeTimeLimiter::computation_, kaldi::nnet3::kAcceptInput, kaldi::nnet3::kAddRowRanges, kaldi::nnet3::kAddRows, kaldi::nnet3::kAddRowsMulti, kaldi::nnet3::kAddToRowsMulti, KALDI_ERR, kaldi::nnet3::kAllocMatrix, kaldi::nnet3::kBackprop, kaldi::nnet3::kBackpropNoModelUpdate, kaldi::nnet3::kCompressMatrix, kaldi::nnet3::kCopyRows, kaldi::nnet3::kCopyRowsMulti, kaldi::nnet3::kCopyToRowsMulti, kaldi::nnet3::kDeallocMatrix, kaldi::nnet3::kDecompressMatrix, kaldi::nnet3::kGotoLabel, kaldi::nnet3::kMatrixAdd, kaldi::nnet3::kMatrixCopy, kaldi::nnet3::kNoOperation, kaldi::nnet3::kNoOperationLabel, kaldi::nnet3::kNoOperationMarker, kaldi::nnet3::kNoOperationPermanent, kaldi::nnet3::kPropagate, kaldi::nnet3::kProvideOutput, kaldi::nnet3::kSetConst, and kaldi::nnet3::kSwapMatrix.

3501  {
3502  int32 num_commands = computation_.commands.size();
3503  expanded_computation_->commands.resize(num_commands);
3504  for (int32 command_index = 0; command_index < num_commands;
3505  command_index++) {
3506  const NnetComputation::Command &c = computation_.commands[command_index];
3507  NnetComputation::Command &c_out =
3508  expanded_computation_->commands[command_index];
3509  c_out = c;
3510  // Commands that only operate on submatrices, components and
3511  // precomputed-indexes do not have to be changed because we'll take care of
3512  // the expansion by suitably redefining the matrices and submatrices, and
3513  // recreating the precomputed-indexes.
3514  // However, commands that require, 'indexes', 'indexes_multi' or
3515  // 'indexes_ranges' do need to be modified.
3516  switch (c.command_type) {
3517  case kAllocMatrix:
3518  case kDeallocMatrix:
3519  case kSetConst:
3520  case kSwapMatrix:
3521  case kPropagate: case kBackprop:
3523  break;
3524  case kCopyRows: case kAddRows:
3525  ExpandRowsCommand(c, &c_out);
3526  break;
3527  case kCopyRowsMulti: case kAddRowsMulti:
3528  case kCopyToRowsMulti: case kAddToRowsMulti:
3529  ExpandRowsMultiCommand(c, &c_out);
3530  break;
3531  case kAddRowRanges:
3532  ExpandRowRangesCommand(c, &c_out);
3533  break;
3535  case kAcceptInput: case kProvideOutput: case kNoOperation:
3537  case kNoOperationLabel: case kGotoLabel:
3538  break;
3539  default:
3540  KALDI_ERR << "Un-handled command type";
3541  }
3542  }
3543 }
void ExpandRowsMultiCommand(const NnetComputation::Command &c_in, NnetComputation::Command *c_out)
void ExpandRowRangesCommand(const NnetComputation::Command &c_in, NnetComputation::Command *c_out)
kaldi::int32 int32
std::vector< Command > commands
#define KALDI_ERR
Definition: kaldi-error.h:147
void ExpandRowsCommand(const NnetComputation::Command &c_in, NnetComputation::Command *c_out)

◆ ComputeDebugInfo()

void ComputeDebugInfo ( )
private

Definition at line 3603 of file nnet-optimize-utils.cc.

References NnetComputation::MatrixDebugInfo::cindexes, DerivativeTimeLimiter::computation_, NnetComputation::MatrixDebugInfo::is_deriv, KALDI_ASSERT, NnetComputation::matrices, NnetComputation::matrix_debug_info, and rnnlm::n.

3603  {
3604  int32 num_matrices = computation_.matrices.size();
3605  KALDI_ASSERT(computation_.matrix_debug_info.size() == num_matrices);
3606  expanded_computation_->matrix_debug_info.resize(num_matrices);
3607  // Matrix zero is a special case; it's the empty matrix.
3610  int32 num_n_values = num_n_values_;
3611  for (int32 m = 1; m < num_matrices; m++) {
3612  const NnetComputation::MatrixDebugInfo &info_in =
3614  NnetComputation::MatrixDebugInfo &info_out =
3616  info_out.is_deriv = info_in.is_deriv;
3617  int32 num_rows_in = computation_.matrices[m].num_rows,
3618  num_rows_out = expanded_computation_->matrices[m].num_rows;
3619  KALDI_ASSERT(num_rows_in == info_in.cindexes.size());
3620  info_out.cindexes.resize(num_rows_out);
3621  const Cindex *cindexes_in = &(info_in.cindexes[0]);
3622  Cindex *cindexes_out = &(info_out.cindexes[0]);
3623  for (int32 r = 0; r < num_rows_in; r++) {
3624  if (info_in.cindexes[r].second.n == 0) {
3625  int32 new_r = GetNewMatrixLocationInfo(m, r),
3626  n_stride = n_stride_[m];
3627  for (int32 n = 0; n < num_n_values; n++) {
3628  int32 r_out = new_r + n * n_stride;
3629  cindexes_out[r_out] = cindexes_in[r];
3630  cindexes_out[r_out].second.n = n;
3631  }
3632  }
3633  }
3634  }
3635 }
std::vector< MatrixDebugInfo > matrix_debug_info
kaldi::int32 int32
std::vector< MatrixInfo > matrices
int32 GetNewMatrixLocationInfo(int32 old_matrix_index, int32 old_row_index) const
This function is used in mapping row-indexes into matrices, from the old to the new computation...
std::pair< int32, Index > Cindex
Definition: nnet-common.h:115
struct rnnlm::@11::@12 n
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ ComputeMatrixInfo()

void ComputeMatrixInfo ( )
private

Definition at line 3588 of file nnet-optimize-utils.cc.

References DerivativeTimeLimiter::computation_, and NnetComputation::matrices.

3588  {
3589  int32 num_matrices = computation_.matrices.size();
3590  expanded_computation_->matrices.resize(num_matrices);
3591  // Matrix zero is a special case; it's the empty matrix.
3593  int32 old_num_n_values = 2,
3594  new_num_n_values = num_n_values_;
3595  for (int32 m = 1; m < num_matrices; m++) {
3597  expanded_computation_->matrices[m].num_rows =
3598  (computation_.matrices[m].num_rows / old_num_n_values) *
3599  new_num_n_values;
3600  }
3601 }
kaldi::int32 int32
std::vector< MatrixInfo > matrices

◆ ComputePrecomputedIndexes()

void ComputePrecomputedIndexes ( )
private

Definition at line 3676 of file nnet-optimize-utils.cc.

References NnetComputation::Command::arg1, NnetComputation::Command::arg2, NnetComputation::Command::command_type, NnetComputation::commands, NnetComputation::component_precomputed_indexes, DerivativeTimeLimiter::computation_, NnetComputation::PrecomputedIndexesInfo::data, Nnet::GetComponent(), NnetComputation::PrecomputedIndexesInfo::input_indexes, KALDI_ASSERT, kaldi::nnet3::kBackprop, kaldi::nnet3::kBackpropNoModelUpdate, kaldi::nnet3::kPropagate, DerivativeTimeLimiter::nnet_, NnetComputation::PrecomputedIndexesInfo::output_indexes, and Component::PrecomputeIndexes().

3676  {
3677  // for each element of 'component_precomputed_indexes',
3678  // we will try to work out the command-index of the associated
3679  // Propagate() command and of the associated Backprop() command,
3680  // if it exists.
3681  // We expect that each such element will be associated with
3682  // exactly one Propagate() command and at most one Backprop() command.
3683  int32 num_commands = computation_.commands.size(),
3684  num_precomputed_indexes = computation_.component_precomputed_indexes.size();
3685 
3686  std::vector<bool> need_backprop(num_precomputed_indexes, false);
3687 
3688  std::vector<int32> component_index(num_precomputed_indexes, -1);
3689 
3690  for (int32 command_index = 0; command_index < num_commands; command_index++) {
3691  const NnetComputation::Command &c = computation_.commands[command_index];
3692 
3693  if (c.command_type == kPropagate && c.arg2 > 0) {
3694  KALDI_ASSERT(c.arg2 < num_precomputed_indexes);
3695  component_index[c.arg2] = c.arg1;
3696  }
3697  if ((c.command_type == kBackprop ||
3698  c.command_type == kBackpropNoModelUpdate) && c.arg2 > 0) {
3699  KALDI_ASSERT(c.arg2 < num_precomputed_indexes);
3700  need_backprop[c.arg2] = true;
3701  }
3702  }
3703 
3704  for (size_t p = 1;
3706  ++p)
3710  num_precomputed_indexes);
3711 
3712  for (int32 p = 1; p < num_precomputed_indexes; ++p) {
3713  const NnetComputation::PrecomputedIndexesInfo &old_info =
3715  NnetComputation::PrecomputedIndexesInfo &new_info =
3717  KALDI_ASSERT(!old_info.input_indexes.empty() &&
3718  !old_info.output_indexes.empty() &&
3719  "Input/output indexes not present in precomputed info of "
3720  "computation to be expanded.");
3721  // note: we could place these expanded indexes into 'new_info.input_indexes'
3722  // and 'new_info.output_indexes', but we actually don't need to keep them
3723  // there, because they are only required to be kept in computations where
3724  // the n indexes consist of the set (0, 1), and the computation we're
3725  // creating has more distinct n indexes than that.
3726  std::vector<Index> input_indexes, output_indexes;
3727  ExpandIndexes(old_info.input_indexes, &input_indexes);
3728  ExpandIndexes(old_info.output_indexes, &output_indexes);
3729  KALDI_ASSERT(component_index[p] >= 0);
3730  const Component *component = nnet_.GetComponent(component_index[p]);
3731  ComponentPrecomputedIndexes *expanded_precomputed_indexes =
3732  component->PrecomputeIndexes(misc_info_, input_indexes,
3733  output_indexes, need_backprop[p]);
3734  // this object should not be null because it was not NULL the
3735  // last time we generated it from the same component, for the
3736  // same computation.
3737  KALDI_ASSERT(expanded_precomputed_indexes != NULL);
3738  new_info.data = expanded_precomputed_indexes;
3739  }
3740 }
const MiscComputationInfo & misc_info_
kaldi::int32 int32
virtual ComponentPrecomputedIndexes * PrecomputeIndexes(const MiscComputationInfo &misc_info, const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, bool need_backprop) const
This function must return NULL for simple Components.
std::vector< Command > commands
void ExpandIndexes(const std::vector< Index > &indexes, std::vector< Index > *indexes_expanded) const
Component * GetComponent(int32 c)
Return component indexed c. Not a copy; not owned by caller.
Definition: nnet-nnet.cc:150
std::vector< PrecomputedIndexesInfo > component_precomputed_indexes
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ ComputeSubmatrixInfo()

void ComputeSubmatrixInfo ( )
private

Definition at line 3637 of file nnet-optimize-utils.cc.

References NnetComputation::MatrixDebugInfo::cindexes, NnetComputation::SubMatrixInfo::col_offset, DerivativeTimeLimiter::computation_, NnetComputation::GetSubmatrixStrings(), KALDI_ERR, NnetComputation::matrix_debug_info, NnetComputation::SubMatrixInfo::matrix_index, DerivativeTimeLimiter::nnet_, NnetComputation::SubMatrixInfo::num_cols, NnetComputation::SubMatrixInfo::num_rows, NnetComputation::Print(), NnetComputation::SubMatrixInfo::row_offset, and NnetComputation::submatrices.

3637  {
3638  int32 num_submatrices = computation_.submatrices.size();
3639  expanded_computation_->submatrices.resize(num_submatrices);
3640  // Sub-matrix zero is a special case; it's the empty submatrix.
3642  for (int32 s = 1; s < num_submatrices; s++) {
3643  const NnetComputation::SubMatrixInfo &info_in = computation_.submatrices[s];
3644  int32 m = info_in.matrix_index;
3645  const NnetComputation::MatrixDebugInfo &debug_info_in =
3647 
3648  // we may need to change the row_offset and num_rows.
3649  int32 first_row_in = info_in.row_offset,
3650  last_row_in = first_row_in + info_in.num_rows - 1;
3651  if (!(debug_info_in.cindexes[first_row_in].second.n == 0 &&
3652  debug_info_in.cindexes[last_row_in].second.n == 1)) {
3653  std::ostringstream computation_ss;
3654  std::vector<std::string> submat_strings;
3655  computation_.GetSubmatrixStrings(nnet_, &submat_strings);
3656  computation_.Print(computation_ss, nnet_);
3657  KALDI_ERR << "Submatrix s" << s << " = " << submat_strings[s]
3658  << " has strange dimensions. Computation is: "
3659  << computation_ss.str();
3660  }
3661 
3662  int32 first_row_out = GetNewMatrixLocationInfo(m, first_row_in),
3663  last_row_out = GetNewMatrixLocationInfo(m, last_row_in),
3664  new_num_rows = (last_row_out + 1 - first_row_out);
3665 
3666  NnetComputation::SubMatrixInfo &info_out =
3668  info_out.matrix_index = m;
3669  info_out.row_offset = first_row_out;
3670  info_out.num_rows = new_num_rows;
3671  info_out.col_offset = info_in.col_offset;
3672  info_out.num_cols = info_in.num_cols;
3673  }
3674 }
std::vector< MatrixDebugInfo > matrix_debug_info
void Print(std::ostream &os, const Nnet &nnet) const
kaldi::int32 int32
int32 GetNewMatrixLocationInfo(int32 old_matrix_index, int32 old_row_index) const
This function is used in mapping row-indexes into matrices, from the old to the new computation...
std::vector< SubMatrixInfo > submatrices
#define KALDI_ERR
Definition: kaldi-error.h:147
void GetSubmatrixStrings(const Nnet &nnet, std::vector< std::string > *submat_strings) const

◆ EnsureDebugInfoExists()

void EnsureDebugInfoExists ( int32  submatrix_index)
private

◆ Expand()

void Expand ( )

Definition at line 3573 of file nnet-optimize-utils.cc.

References DerivativeTimeLimiter::computation_, and NnetComputation::need_model_derivative.

Referenced by kaldi::nnet3::ExpandComputation().

3573  {
3574  InitStrideInfo();
3576  if (need_debug_info_)
3577  ComputeDebugInfo();
3578  else
3582  ComputeCommands();
3583 
3586 }
std::vector< MatrixDebugInfo > matrix_debug_info

◆ ExpandIndexes()

void ExpandIndexes ( const std::vector< Index > &  indexes,
std::vector< Index > *  indexes_expanded 
) const
private

Definition at line 3794 of file nnet-optimize-utils.cc.

References kaldi::nnet3::ConvertNumNValues(), kaldi::nnet3::FindNStride(), and KALDI_ASSERT.

3796  {
3797  bool full_check = false;
3798  int32 n_stride = FindNStride(indexes, full_check);
3799  KALDI_ASSERT(n_stride > 0);
3800  ConvertNumNValues(n_stride, 2, num_n_values_,
3801  indexes, indexes_expanded);
3802 }
static void ConvertNumNValues(int32 n_stride, int32 old_N, int32 new_N, const std::vector< Index > &indexes_in, std::vector< Index > *indexes_out)
static int32 FindNStride(const std::vector< Index > &indexes, bool full_check)
kaldi::int32 int32
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ ExpandRowRangesCommand()

void ExpandRowRangesCommand ( const NnetComputation::Command & c_in,
NnetComputation::Command * c_out 
)
private

Definition at line 3426 of file nnet-optimize-utils.cc.

References NnetComputation::Command::arg1, NnetComputation::Command::arg2, NnetComputation::Command::arg3, DerivativeTimeLimiter::computation_, NnetComputation::indexes_ranges, KALDI_ASSERT, rnnlm::n, and NnetComputation::submatrices.

3428  {
3429  // we need to expand the pairs of row-indexes in c_in.arg2, and put the index
3430  // of the resulting vector<int> in expanded_computation_->indexes_ranges, in
3431  // 'c_out->arg2'.
3432 
3433  int32 s1 = c_in.arg1, s2 = c_in.arg2,
3434  num_rows_old = computation_.submatrices[s1].num_rows,
3435  num_rows_new = expanded_computation_->submatrices[s1].num_rows;
3436  KALDI_ASSERT(static_cast<size_t>(c_in.arg3) <
3437  computation_.indexes_ranges.size());
3438  int32 num_n_values = num_n_values_;
3439 
3440  int32 old_arg3 = c_out->arg3;
3441  c_out->arg3 = expanded_computation_->indexes_ranges.size();
3443  std::vector<std::pair<int32, int32> >());
3444  std::vector<std::pair<int32, int32> > &new_indexes_ranges =
3446  const std::vector<std::pair<int32, int32> > &old_indexes_ranges =
3447  computation_.indexes_ranges[old_arg3];
3448  // old_indexes_ranges is a vector that has the same size as the num-rows of
3449  // submatrix s1. It contains pairs that are either two copies of the same
3450  // value (in practice the pair (-1, -1)), or pairs (begin-row-index,
3451  // end-row-index) representing the (begin,end) of a range in submatrix s2.
3452  // Note: end-row-index is one past the end of the range, as for C++ iterators.
3453 
3454  KALDI_ASSERT(static_cast<int32>(old_indexes_ranges.size()) == num_rows_old);
3455 
3456  new_indexes_ranges.resize(num_rows_new,
3457  std::pair<int32,int32>(-1, -1));
3458 
3459  for (int32 i1 = 0; i1 < num_rows_old; i1++) {
3460  int32 new_i1_n0, n_stride1;
3461  if (GetNewSubmatLocationInfo(s1, i1, &new_i1_n0, &n_stride1)) {
3462  // GetNewSubmatLocationInfo() returns true if this corresponds to
3463  // a Cindex with n == 0.
3464  int32 i2_begin = old_indexes_ranges[i1].first,
3465  i2_end = old_indexes_ranges[i1].second;
3466  if (i2_end == i2_begin)
3467  continue; // (-1, -1) pair, meaning an empty range.
3468  // 'new_indexes_ranges' is filled with (-1, -1) pairs as a
3469  // default so we don't have to do anything for these
3470  // elements.
3471  int32 i2_last = i2_end - 1;
3472  int32 new_i2_n0_begin, new_i2_n0_last,
3473  n_stride2; // only 1 stride variable; both calls will output
3474  // the same value.
3475 
3476  bool ans1 = GetNewSubmatLocationInfo(s2, i2_begin, &new_i2_n0_begin,
3477  &n_stride2),
3478  ans2 = GetNewSubmatLocationInfo(s2, i2_last, &new_i2_n0_last,
3479  &n_stride2);
3480  KALDI_ASSERT(ans1 && ans2 && new_i2_n0_last >= new_i2_n0_begin &&
3481  new_i2_n0_begin >= 0 && n_stride1 > 0 && n_stride2 > 0);
3482  // source should also be for n==0, because we don't (or at least
3483  // shouldn't) create computations that mix up the 'n' values
3484 
3485 
3486  int32 new_i1 = new_i1_n0,
3487  new_i2_begin = new_i2_n0_begin,
3488  new_i2_end = new_i2_n0_last + 1;
3489  for (int32 n = 0; n < num_n_values;
3490  n++, new_i1 += n_stride1, new_i2_begin += n_stride2,
3491  new_i2_end += n_stride2) {
3492  new_indexes_ranges[new_i1].first = new_i2_begin;
3493  new_indexes_ranges[new_i1].second = new_i2_end;
3494  }
3495  }
3496  }
3497 }
kaldi::int32 int32
std::vector< SubMatrixInfo > submatrices
struct rnnlm::@11::@12 n
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
bool GetNewSubmatLocationInfo(int32 submat_index, int32 old_row_index, int32 *new_row_index, int32 *n_stride) const
std::vector< std::vector< std::pair< int32, int32 > > > indexes_ranges

◆ ExpandRowsCommand()

void ExpandRowsCommand ( const NnetComputation::Command & c_in,
NnetComputation::Command * c_out 
)
private

Definition at line 3298 of file nnet-optimize-utils.cc.

References NnetComputation::Command::alpha, NnetComputation::Command::arg1, NnetComputation::Command::arg2, NnetComputation::Command::arg3, DerivativeTimeLimiter::computation_, NnetComputation::indexes, KALDI_ASSERT, rnnlm::n, and NnetComputation::submatrices.

3300  {
3301  // we need to expand the row-indexes in c_in.arg3, and put the index of the
3302  // resulting vector<int> in expanded_computation_->indexes, in 'c_out->arg3'.
3303 
3304  int32 s1 = c_in.arg1, s2 = c_in.arg2;
3305 
3306  // The command that gets called is something like
3307  // submat1.AddRows(submat2, indexes) if submat1 is the submatrix referred to in
3308  // 's1' and submat2 is the submatrix referred to in 's2'.
3309  // 'indexes' has the same size as the num-rows of submat1, and the values
3310  // in the vector are row-indexes into s2.
3311  int32 old_arg3 = c_out->arg3;
3312  c_out->arg3 = expanded_computation_->indexes.size();
3313  c_out->alpha = c_in.alpha;
3314  expanded_computation_->indexes.push_back(std::vector<int32>());
3315  std::vector<int32> &new_indexes = expanded_computation_->indexes.back();
3316  const std::vector<int32> &old_indexes = computation_.indexes[old_arg3];
3317 
3318  int32 old_size = old_indexes.size(),
3319  num_n_values = num_n_values_,
3320  new_s1_size = expanded_computation_->submatrices[s1].num_rows,
3321  new_s2_size = expanded_computation_->submatrices[s2].num_rows;
3322 
3323  KALDI_ASSERT(old_size == computation_.submatrices[s1].num_rows);
3324 
3325  new_indexes.resize(new_s1_size, -1);
3326 
3327 
3328  // A note on the variable names: i1 and i2 are indexes into the destination
3329  // submatrix and the source submatrix respectively, of the CopyRows or AddRows
3330  // command.
3331  // "n0" in the variable name means that this corresponds to an Index with n==0.
3332  // things without "new" in the name refer to the old computation; things with
3333  // "new" in the name refer to the computation that we are generating.
3334  for (int32 i1 = 0; i1 < old_size; i1++) {
3335  int32 new_i1_n0, n_stride1;
3336  if (GetNewSubmatLocationInfo(s1, i1, &new_i1_n0, &n_stride1)) {
3337  // GetNewSubmatLocationInfo() returns true if this corresponds to
3338  // a Cindex with n == 0.
3339  int32 i2 = old_indexes[i1]; // note: i2 is the row index into submatrix s2.
3340  int32 new_i2_n0, n_stride2;
3341  if (i2 < 0) { // if i2 is -1, we'll just leave any relevant positions in
3342  // 'new_indexes' with -1's in them.
3343  continue;
3344  } else {
3345  bool ans = GetNewSubmatLocationInfo(s2, i2, &new_i2_n0, &n_stride2);
3346  KALDI_ASSERT(ans); // source should also be for n==0, because we don't
3347  // (or at least shouldn't) create computations that
3348  // mix up the 'n' values
3349 
3350  int32 new_i1 = new_i1_n0, new_i2 = new_i2_n0;
3351  for (int32 n = 0; n < num_n_values;
3352  ++n, new_i1 += n_stride1, new_i2 += n_stride2) {
3353  KALDI_ASSERT(new_i1 < new_s1_size && new_i2 < new_s2_size);
3354  new_indexes[new_i1] = new_i2;
3355  }
3356  }
3357  }
3358  }
3359 }
kaldi::int32 int32
std::vector< SubMatrixInfo > submatrices
struct rnnlm::@11::@12 n
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
std::vector< std::vector< int32 > > indexes
bool GetNewSubmatLocationInfo(int32 submat_index, int32 old_row_index, int32 *new_row_index, int32 *n_stride) const

◆ ExpandRowsMultiCommand()

void ExpandRowsMultiCommand ( const NnetComputation::Command & c_in,
NnetComputation::Command * c_out 
)
private

Definition at line 3361 of file nnet-optimize-utils.cc.

References NnetComputation::Command::arg1, NnetComputation::Command::arg2, DerivativeTimeLimiter::computation_, NnetComputation::indexes_multi, KALDI_ASSERT, rnnlm::n, and NnetComputation::submatrices.

3363  {
3364  // we need to expand the (submatrix,row)-index pairs in c_in.arg2, and put the
3365  // index of the resulting vector<int> in expanded_computation_->indexes_multi,
3366  // in 'c_out->arg2'.
3367 
3368  int32 s1 = c_in.arg1,
3369  num_rows_old = computation_.submatrices[s1].num_rows,
3370  num_rows_new = expanded_computation_->submatrices[s1].num_rows;
3371 
3372  KALDI_ASSERT(num_rows_old % 2 == 0);
3373  int32 num_n_values = num_n_values_;
3374 
3375  int32 old_arg2 = c_out->arg2;
3376  c_out->arg2 = expanded_computation_->indexes_multi.size();
3378  std::vector<std::pair<int32, int32> >());
3379  std::vector<std::pair<int32, int32> > &new_indexes_multi =
3381  const std::vector<std::pair<int32, int32> > &old_indexes_multi =
3382  computation_.indexes_multi[old_arg2];
3383  // old_indexes_multi is a vector that has the same size as the num-rows
3384  // of submatrix s1. It contains pairs that are either (-1, -1), or
3385  // pairs (submatrix-index, row-index) referring to other submatrices
3386  // in the computation.
3387 
3388  KALDI_ASSERT(static_cast<int32>(old_indexes_multi.size()) == num_rows_old);
3389 
3390 
3391  new_indexes_multi.resize(num_rows_new,
3392  std::pair<int32,int32>(-1, -1));
3393 
3394  for (int32 i1 = 0; i1 < num_rows_old; i1++) {
3395  int32 new_i1_n0, n_stride1;
3396  if (GetNewSubmatLocationInfo(s1, i1, &new_i1_n0, &n_stride1)) {
3397  // GetNewSubmatLocationInfo() returns true if this corresponds to
3398  // a Cindex with n == 0.
3399  int32 s2 = old_indexes_multi[i1].first,
3400  i2 = old_indexes_multi[i1].second;
3401  int32 new_i2_n0, n_stride2;
3402  if (s2 < 0) { // if s2 is -1, we don't have to do anything... we'd have
3403  // to fill any relevant positions in 'new_indexes_multi'
3404  // with (-1,-1)'s, but it's filled with that by default.
3405  continue;
3406  } else {
3407  bool ans = GetNewSubmatLocationInfo(s2, i2, &new_i2_n0, &n_stride2);
3408  KALDI_ASSERT(ans); // source should also be for n==0, because we don't
3409  // (or at least shouldn't) create computations that
3410  // mix up the 'n' values
3411 
3412  int32 new_i1 = new_i1_n0, new_i2 = new_i2_n0;
3413 
3414  for (int32 n = 0; n < num_n_values;
3415  n++, new_i1 += n_stride1, new_i2 += n_stride2) {
3416  new_indexes_multi[new_i1].first = s2;
3417  new_indexes_multi[new_i1].second = new_i2;
3418  }
3419  }
3420  }
3421  }
3422 }
kaldi::int32 int32
std::vector< std::vector< std::pair< int32, int32 > > > indexes_multi
std::vector< SubMatrixInfo > submatrices
struct rnnlm::@11::@12 n
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
bool GetNewSubmatLocationInfo(int32 submat_index, int32 old_row_index, int32 *new_row_index, int32 *n_stride) const

◆ GetNewMatrixLocationInfo()

int32 GetNewMatrixLocationInfo ( int32  old_matrix_index,
int32  old_row_index 
) const
private

This function is used in mapping row-indexes into matrices, from the old to the new computation.

Parameters
[in] matrix_index: The matrix-index > 0 for which we are mapping row-indexes (this parameter is named 'old_matrix_index' in the declaration). The matrix-indexes are the same in the old and new computations.
[in] old_row_index: The old row-index into the matrix.
Returns
This function returns the row-index where the cindex referred to in 'old_matrix_index' will reside in the new, expanded computation, WITH THE CAVEAT THAT if the old cindex had n == 1, we'll output the location of the cindex with n == num_n_values_ - 1. This happens to be what we want (it maps the last n value on the input to the last n value on the output).

Definition at line 3761 of file nnet-optimize-utils.cc.

References DerivativeTimeLimiter::computation_, KALDI_ASSERT, and NnetComputation::matrix_debug_info.

3762  {
3763  // to understand 'block_size', read the comment for FindNStride().
3764  int32 n_stride = n_stride_[matrix_index],
3765  old_num_n_values = 2, new_num_n_values = num_n_values_,
3766  old_block_size = old_num_n_values * n_stride,
3767  new_block_size = new_num_n_values * n_stride,
3768  block_index = old_row_index / old_block_size,
3769  offset_within_block = old_row_index % old_block_size;
3770 
3771  // within each block, we can show, given our assumptions, that
3772  // we must first have a sub-block of 'n_stride' values all with
3773  // n == 0, then another sub-block of 'n_stride' values all with
3774  // n == 1, and so on. [except there is no 'and so on' for the
3775  // input computation, where we expect the 'n' values to be the
3776  // set {0, 1}.]
3777  int32 old_n_value = offset_within_block / n_stride,
3778  index_within_subblock = offset_within_block % n_stride;
3779  const std::vector<Cindex> &cindexes =
3780  computation_.matrix_debug_info[matrix_index].cindexes;
3781  KALDI_ASSERT(old_n_value == cindexes[old_row_index].second.n &&
3782  (old_n_value == 0 || old_n_value == 1));
3783  // Search for CAVEAT in the comment for this function to see what this is
3784  // about. Mapping old_n_value == 1 -> new_n_value == new_num_n_values - 1
3785  // just happens to be useful for the way we use this function... it maps the
3786  // end of an old submatrix to the end of a new submatrix.
3787  int32 new_n_value = (old_n_value == 0 ? 0 : new_num_n_values - 1);
3788 
3789  return block_index * new_block_size + index_within_subblock +
3790  new_n_value * n_stride;
3791 }
std::vector< MatrixDebugInfo > matrix_debug_info
kaldi::int32 int32
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ GetNewSubmatLocationInfo()

bool GetNewSubmatLocationInfo ( int32  submat_index,
int32  old_row_index,
int32 * new_row_index,
int32 * n_stride 
) const
private

Definition at line 3743 of file nnet-optimize-utils.cc.

References NnetComputation::MatrixDebugInfo::cindexes, DerivativeTimeLimiter::computation_, NnetComputation::matrix_debug_info, and NnetComputation::submatrices.

3745  {
3746  int32 matrix_index = computation_.submatrices[submat_index].matrix_index,
3747  old_row_offset = computation_.submatrices[submat_index].row_offset,
3748  new_row_offset = expanded_computation_->submatrices[submat_index].row_offset;
3749 
3750  const NnetComputation::MatrixDebugInfo &debug_info_in =
3751  computation_.matrix_debug_info[matrix_index];
3752  if (debug_info_in.cindexes[old_row_index + old_row_offset].second.n != 0)
3753  return false;
3754  *new_row_index = (GetNewMatrixLocationInfo(matrix_index,
3755  old_row_index + old_row_offset) -
3756  new_row_offset);
3757  *n_stride = n_stride_[matrix_index];
3758  return true;
3759 }
std::vector< MatrixDebugInfo > matrix_debug_info
kaldi::int32 int32
int32 GetNewMatrixLocationInfo(int32 old_matrix_index, int32 old_row_index) const
This function is used in mapping row-indexes into matrices, from the old to the new computation...
std::vector< SubMatrixInfo > submatrices

◆ InitStrideInfo()

void InitStrideInfo ( )
private

Definition at line 3548 of file nnet-optimize-utils.cc.

References NnetComputation::MatrixDebugInfo::cindexes, DerivativeTimeLimiter::computation_, kaldi::nnet3::FindNStride(), KALDI_ASSERT, KALDI_ERR, NnetComputation::matrices, and NnetComputation::matrix_debug_info.

3548  {
3549  // note: the zeroth matrix is not a real matrix, it's the empty matrix.
3550  int32 num_matrices = computation_.matrices.size();
3551  n_stride_.resize(num_matrices);
3552  n_stride_[0] = 0;
3553 
3554  // the input computation to class ComputationExpander is required to
3555  // have its debug info set up.
3557  for (int32 m = 1; m < num_matrices; m++) {
3558  int32 num_rows = computation_.matrices[m].num_rows;
3559  const NnetComputation::MatrixDebugInfo &debug_info = computation_.matrix_debug_info[m];
3560  KALDI_ASSERT(debug_info.cindexes.size() == num_rows);
3561  bool full_check = true; // TODO: eventually change this to false.
3562  int32 n_stride = FindNStride(debug_info.cindexes, full_check);
3563  if (n_stride == 0) {
3564  KALDI_ERR << "Problem encountered in 'shortcut' compilation: the computation "
3565  << "does not have the expected structure. Try compiling with "
3566  << "--use-shortcut=false.";
3567  }
3568  n_stride_[m] = n_stride;
3569  }
3570 }
std::vector< MatrixDebugInfo > matrix_debug_info
static int32 FindNStride(const std::vector< Index > &indexes, bool full_check)
kaldi::int32 int32
std::vector< MatrixInfo > matrices
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

Member Data Documentation

◆ computation_

const NnetComputation& computation_
private

Definition at line 3290 of file nnet-optimize-utils.cc.

◆ expanded_computation_

NnetComputation* expanded_computation_
private

Definition at line 3293 of file nnet-optimize-utils.cc.

◆ misc_info_

const MiscComputationInfo& misc_info_
private

Definition at line 3289 of file nnet-optimize-utils.cc.

◆ n_stride_

std::vector<int32> n_stride_
private

Definition at line 3286 of file nnet-optimize-utils.cc.

◆ need_debug_info_

bool need_debug_info_
private

Definition at line 3291 of file nnet-optimize-utils.cc.

◆ nnet_

const Nnet& nnet_
private

Definition at line 3288 of file nnet-optimize-utils.cc.

◆ num_n_values_

int32 num_n_values_
private

Definition at line 3292 of file nnet-optimize-utils.cc.


The documentation for this class was generated from the following file: