RowOpsSplitter Class Reference
Collaboration diagram for RowOpsSplitter:

Classes

struct  MultiIndexSplitInfo
 
struct  SingleSplitInfo
 

Public Member Functions

 RowOpsSplitter (NnetComputation *computation)
 
bool Split ()
 

Private Member Functions

bool SplitIndexes ()
 
bool SplitCommands ()
 
bool SplitCommand (int32 command_index)
 
bool GetSplitInfo (std::vector< std::pair< int32, int32 > >::const_iterator begin, std::vector< std::pair< int32, int32 > >::const_iterator end, SingleSplitInfo *info)
 

Private Attributes

NnetComputation * computation_
 
std::vector< MultiIndexSplitInfo > split_info_
 
std::vector< std::pair< int32, NnetComputation::Command > > new_commands_
 

Detailed Description

Definition at line 2578 of file nnet-optimize-utils.cc.

Constructor & Destructor Documentation

◆ RowOpsSplitter()

RowOpsSplitter ( NnetComputation * computation)
inline

Definition at line 2580 of file nnet-optimize-utils.cc.

2580 : computation_(computation) { }

Member Function Documentation

◆ GetSplitInfo()

bool GetSplitInfo ( std::vector< std::pair< int32, int32 > >::const_iterator  begin,
std::vector< std::pair< int32, int32 > >::const_iterator  end,
SingleSplitInfo * info
)
private

Definition at line 2686 of file nnet-optimize-utils.cc.

References RowOpsSplitter::SingleSplitInfo::first_value, rnnlm::i, KALDI_ASSERT, RowOpsSplitter::SingleSplitInfo::min_second_value, RowOpsSplitter::SingleSplitInfo::second_value_offsets, RowOpsSplitter::SingleSplitInfo::second_value_range, and RowOpsSplitter::SingleSplitInfo::size.

2689  {
2690  // max_size_ratio must be > 1.0, and could in principle be a float. It is
2691  // there to prevent us from making changes to the computation which would end
2692  // up wastefully launching too many kernels that would do nothing.
2693  const int32 max_size_ratio = 2;
2694 
2695  int32 size = end - begin;
2696  KALDI_ASSERT(size != 0);
2697  int32 first = begin->first;
2698  if (first < 0)
2699  return false;
2700  info->size = size;
2701  info->first_value = first;
2702  int32 initial_second_value = begin->second,
2703  min_second_value = initial_second_value,
2704  max_second_value = initial_second_value;
2705  info->second_value_offsets.resize(size);
2706  bool is_consecutive = true;
2707  for (int32 i = 0; i < size; i++) {
2708  int32 second = begin[i].second;
2709  if (begin[i].first != first || second < 0) return false;
2710  info->second_value_offsets[i] = second;
2711  if (second != initial_second_value + i)
2712  is_consecutive = false;
2713  if (second < min_second_value) min_second_value = second;
2714  if (second > max_second_value) max_second_value = second;
2715  }
2716  info->min_second_value = min_second_value;
2717  info->second_value_range = max_second_value + 1 - min_second_value;
2718  if (info->second_value_range > size * max_size_ratio)
2719  return false;
2720  if (is_consecutive) {
2721  info->second_value_offsets.clear();
2722  } else {
2723  for (int32 i = 0; i < size; i++)
2724  info->second_value_offsets[i] -= min_second_value;
2725  }
2726  return true;
2727 }
kaldi::int32 int32
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ Split()

bool Split ( )
inline

Definition at line 2584 of file nnet-optimize-utils.cc.

Referenced by kaldi::nnet3::SplitRowOps().

◆ SplitCommand()

bool SplitCommand ( int32  command_index)
private

Definition at line 2780 of file nnet-optimize-utils.cc.

References NnetComputation::Command::alpha, NnetComputation::Command::arg1, NnetComputation::Command::arg2, NnetComputation::Command::arg3, NnetComputation::Command::command_type, NnetComputation::commands, DerivativeTimeLimiter::computation_, RowOpsSplitter::SingleSplitInfo::first_value, rnnlm::i, NnetComputation::indexes, kaldi::nnet3::kAddRows, kaldi::nnet3::kAddRowsMulti, kaldi::nnet3::kAddToRowsMulti, KALDI_ASSERT, KALDI_ERR, kaldi::nnet3::kCopyRows, kaldi::nnet3::kCopyRowsMulti, kaldi::nnet3::kCopyToRowsMulti, kaldi::nnet3::kMatrixAdd, kaldi::nnet3::kMatrixCopy, RowOpsSplitter::SingleSplitInfo::min_second_value, NnetComputation::NewSubMatrix(), RowOpsSplitter::SingleSplitInfo::offset, RowOpsSplitter::SingleSplitInfo::second_value_offsets, RowOpsSplitter::SingleSplitInfo::second_value_range, RowOpsSplitter::SingleSplitInfo::size, RowOpsSplitter::MultiIndexSplitInfo::splits, and kaldi::swap().

2780  {
2781  NnetComputation::Command &command = computation_->commands[c];
2782  CommandType command_type = command.command_type;
2783  // For commands that are not of the following four types, return false: we
2784  // won't be changing these commands.
2785  switch (command_type) {
2786  case kAddRowsMulti: case kCopyRowsMulti:
2787  case kAddToRowsMulti: case kCopyToRowsMulti: break;
2788  default: return false;
2789  }
2790  int32 indexes_multi_index = command.arg2;
2791  KALDI_ASSERT(indexes_multi_index <
2792  static_cast<int32>(split_info_.size()));
2793  const MultiIndexSplitInfo &split_info = split_info_[indexes_multi_index];
2794  if (split_info.splits.empty())
2795  return false; // these indexes couldn't be split: e.g. they contained more
2796  // than two distinct .first elements, or there were other
2797  // reasons.
2798 
2799  // we'll be splitting the command into either one or two pieces.
2800  std::vector<NnetComputation::Command> split_commands(
2801  split_info.splits.size());
2802  for (size_t i = 0; i < split_info.splits.size(); i++) {
2803  const SingleSplitInfo &split = split_info.splits[i];
2804  NnetComputation::Command &command_out = split_commands[i];
2805  command_out.alpha = command.alpha;
2806  command_out.arg1 = computation_->NewSubMatrix(
2807  command.arg1, split.offset, split.size, 0, -1);
2808  command_out.arg2 = computation_->NewSubMatrix(
2809  split.first_value, split.min_second_value,
2810  split.second_value_range, 0, -1);
2811 
2812  if (split.second_value_offsets.empty()) {
2813  // The .second elements are consecutive.
2814  switch (command_type) {
2815  case kAddRowsMulti:
2816  command_out.command_type = kMatrixAdd;
2817  break;
2818  case kCopyRowsMulti:
2819  command_out.command_type = kMatrixCopy;
2820  break;
2821  case kAddToRowsMulti:
2822  command_out.command_type = kMatrixAdd;
2823  std::swap(command_out.arg1, command_out.arg2);
2824  break;
2825  case kCopyToRowsMulti:
2826  command_out.command_type = kMatrixCopy;
2827  std::swap(command_out.arg1, command_out.arg2);
2828  break;
2829  default: // will never be reached.
2830  break;
2831  }
2832  } else {
2833  // Indexes are not consecutive: it needs to be a kAddRows or kCopyRows
2834  // command.
2835  command_out.arg3 = computation_->indexes.size();
2836  switch (command_type) {
2837  case kAddRowsMulti: case kCopyRowsMulti: {
2838  command_out.command_type = (command_type == kAddRowsMulti ?
2839  kAddRows : kCopyRows);
2840  computation_->indexes.push_back(split.second_value_offsets);
2841  break;
2842  }
2843  case kCopyToRowsMulti: {
2844  // We can't operate on this command because of what would happen
2845  // with values of 'indexes' (see the variable in the block for
2846  // kAddToRowsMulti) which were -1. Rows of the output would be
2847  // set to zero, which is not the behavior we want here; we'd want
2848  // them to be unaffected.
2849  return false;
2850  }
2851  case kAddToRowsMulti: {
2852  command_out.command_type = kAddRows;
2853  std::swap(command_out.arg1, command_out.arg2);
2854  // invert the indexes.
2855  std::vector<int32> indexes(split.second_value_range, -1);
2856  for (int32 i = 0; i < split.size; i++) {
2857  // the following assert should always succeed because the
2858  // AddToRowsMulti and CopyToRowsMulti should never have
2859  // duplicate destinations in their indexes.
2860  KALDI_ASSERT(indexes[split.second_value_offsets[i]] >= 0);
2861  indexes[split.second_value_offsets[i]] = i;
2862  }
2863  computation_->indexes.push_back(indexes);
2864  break;
2865  }
2866  default:
2867  KALDI_ERR << "Code error: un-handled case.";
2868  }
2869  }
2870  }
2871  command = split_commands[0];
2872  // note: for now, split_commands.size() will be 1 or 2.
2873  for (size_t i = 1; i < split_commands.size(); i++) {
2874  new_commands_.resize(new_commands_.size() + 1);
2875  // we'll want to insert this command right after command c,
2876  // which is the same as just before command c + 1.
2877  new_commands_.back().first = c + 1;
2878  new_commands_.back().second = split_commands[i];
2879  }
2880  return true; // We made a change.
2881 }
CommandType
CommandType is an enum that describes the category of the command used in the NnetComputation.
std::vector< std::pair< int32, NnetComputation::Command > > new_commands_
void swap(basic_filebuf< CharT, Traits > &x, basic_filebuf< CharT, Traits > &y)
kaldi::int32 int32
std::vector< Command > commands
int32 NewSubMatrix(int32 base_submatrix, int32 row_offset, int32 num_rows, int32 col_offset, int32 num_cols)
Convenience function used when adding new sub-matrices.
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
std::vector< std::vector< int32 > > indexes
std::vector< MultiIndexSplitInfo > split_info_

◆ SplitCommands()

bool SplitCommands ( )
private

Definition at line 2883 of file nnet-optimize-utils.cc.

References NnetComputation::commands, DerivativeTimeLimiter::computation_, and kaldi::nnet3::InsertCommands().

2883  {
2884  bool ans = false;
2885  int32 num_commands = computation_->commands.size();
2886  for (int32 c = 0; c < num_commands; c++)
2887  if (SplitCommand(c))
2888  ans = true;
2889  if (!new_commands_.empty())
2890  InsertCommands(&new_commands_, computation_);
2891  return ans;
2892 }
void InsertCommands(std::vector< std::pair< int32, NnetComputation::Command > > *new_commands, NnetComputation *computation)
Inserts commands into the computation at the requested places.
bool SplitCommand(int32 command_index)
std::vector< std::pair< int32, NnetComputation::Command > > new_commands_
kaldi::int32 int32
std::vector< Command > commands

◆ SplitIndexes()

bool SplitIndexes ( )
private

Definition at line 2730 of file nnet-optimize-utils.cc.

References DerivativeTimeLimiter::computation_, rnnlm::i, NnetComputation::indexes_multi, rnnlm::j, KALDI_ASSERT, and RowOpsSplitter::MultiIndexSplitInfo::splits.

2730  {
2731  bool ans = false;
2732  int32 num_indexes_multi = computation_->indexes_multi.size();
2733  split_info_.resize(num_indexes_multi);
2734  for (int32 i = 0; i < num_indexes_multi; i++) {
2735  const std::vector<std::pair<int32,int32> > &multi_index =
2736  computation_->indexes_multi[i];
2737  MultiIndexSplitInfo &split_info = split_info_[i];
2738 
2739  int32 num_pairs = multi_index.size();
2740  KALDI_ASSERT(num_pairs > 0);
2741  // 'split_point' will be set to the first index j for which
2742  // multi_index[j-1].first != multi_index[j].first, or -1
2743  // if no such j exists.
2744  int32 split_point = -1, initial_first = multi_index[0].first;
2745  for (int32 j = 1; j < num_pairs; j++) {
2746  if (multi_index[j].first != initial_first) {
2747  split_point = j;
2748  break;
2749  }
2750  }
2751  if (split_point == -1) {
2752  split_info.splits.resize(1);
2753  split_info.splits[0].offset = 0;
2754  if (!GetSplitInfo(multi_index.begin(), multi_index.end(),
2755  &(split_info.splits[0]))) {
2756  split_info.splits.clear();
2757  } else {
2758  ans = true;
2759  }
2760  } else {
2761  split_info.splits.resize(2);
2762  split_info.splits[0].offset = 0;
2763  split_info.splits[1].offset = split_point;
2764 
2765  std::vector<std::pair<int32,int32> >::const_iterator mid_iter =
2766  multi_index.begin() + split_point;
2767  if (!GetSplitInfo(multi_index.begin(), mid_iter,
2768  &(split_info.splits[0])) ||
2769  !GetSplitInfo(mid_iter, multi_index.end(),
2770  &(split_info.splits[1]))) {
2771  split_info.splits.clear();
2772  } else {
2773  ans = true;
2774  }
2775  }
2776  }
2777  return ans;
2778 }
kaldi::int32 int32
std::vector< std::vector< std::pair< int32, int32 > > > indexes_multi
bool GetSplitInfo(std::vector< std::pair< int32, int32 > >::const_iterator begin, std::vector< std::pair< int32, int32 > >::const_iterator end, SingleSplitInfo *info)
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
std::vector< MultiIndexSplitInfo > split_info_

Member Data Documentation

◆ computation_

NnetComputation* computation_
private

Definition at line 2673 of file nnet-optimize-utils.cc.

◆ new_commands_

std::vector<std::pair<int32, NnetComputation::Command> > new_commands_
private

Definition at line 2681 of file nnet-optimize-utils.cc.

◆ split_info_

std::vector<MultiIndexSplitInfo> split_info_
private

Definition at line 2676 of file nnet-optimize-utils.cc.


The documentation for this class was generated from the following file: