LimitRankClass Class Reference
Collaboration diagram for LimitRankClass:

Public Member Functions

 LimitRankClass (const NnetLimitRankOpts &opts, int32 c, Nnet *nnet)
 
void operator() ()
 
int32 GetRetainedDim (int32 rows, int32 cols)
 
 ~LimitRankClass ()
 

Private Attributes

const NnetLimitRankOpts & opts_
 
int32 c_
 
Nnet * nnet_
 

Detailed Description

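LimitRankClass is a functor that limits the rank of the linear parameters of one AffineComponent (index c_) of an Nnet. Calling operator() does the following:
- computes the SVD of that component's linear-parameter matrix;
- keeps only the d largest singular values and their vectors, where d is chosen by GetRetainedDim() from opts_.parameter_proportion;
- rebuilds the matrix from the truncated factors;
- writes the matrix back, leaving the bias vector unchanged.

Definition at line 26 of file nnet-limit-rank.cc.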

Constructor & Destructor Documentation

◆ LimitRankClass()

LimitRankClass ( const NnetLimitRankOpts & opts,
int32  c,
Nnet * nnet 
)
inline

Definition at line 28 of file nnet-limit-rank.cc.

30  : opts_(opts), c_(c), nnet_(nnet) { }
Symbols referenced in this listing:
- const NnetLimitRankOpts & opts_

◆ ~LimitRankClass()

~LimitRankClass ( )
inline

Definition at line 91 of file nnet-limit-rank.cc.

91 { }

Member Function Documentation

◆ GetRetainedDim()

int32 GetRetainedDim ( int32  rows,
int32  cols 
)
inline

Definition at line 62 of file nnet-limit-rank.cc.

References KALDI_ASSERT, KALDI_ERR, LimitRankClass::opts_, and NnetLimitRankOpts::parameter_proportion.

Referenced by LimitRankClass::operator()().

62  {
63  if (opts_.parameter_proportion <= 0.0 || opts_.parameter_proportion > 1.0)
64  KALDI_ERR << "bad --parameter-proportion " << opts_.parameter_proportion;
65  // If we do SVD to dimension d, so that it's U diag(s) V^T where
66  // U is rows * d, s is d, and V is cols * d, then the #params is as follows...
67  // the first column of U has free parameters (#rows - 1) [the -1 is due to
68  // the length constraint]; the second has (#rows - 2) [subtract 1 for the
69  // length constraint and one for orthogonality with the previous row], etc.
70  // Total is params(U) = (rows * d) - ((d(d+1))/2),
71  // params(s) = d,
72  // params(V) = (cols * d) - ((d(d+1))/2),
73  // So total is (rows + cols) * d - d * d .
74  // For example, if d = #rows, this equals (#rows * #cols)
75  // We are solving for:
76  // (rows * cols) * parameter_proportion = (rows + cols) * d - d * d, or
77  // d^2 - d * (rows + cols) + (rows*cols)*parameter_proportion
78  // In quadratic equation
79  // a = 1.0,
80  // b = -(rows + cols)
81  // c = rows * cols * parameter_proportion.
82  // Take smaller solution.
83  BaseFloat a = 1.0, b = -(rows + cols),
84  c = rows * cols * opts_.parameter_proportion;
85  BaseFloat x = (-b - sqrt(b * b - 4 * a * c)) / (2.0 * a);
86  int32 ans = static_cast<int32>(x);
87  KALDI_ASSERT(ans > 0 && ans <= std::min(rows, cols));
88  return ans;
89  }
Symbols referenced in this listing:
- kaldi::int32 int32
- const NnetLimitRankOpts & opts_
- float BaseFloat (Definition: kaldi-types.h:29)
- #define KALDI_ERR (Definition: kaldi-error.h:147)
- #define KALDI_ASSERT(cond) (Definition: kaldi-error.h:185)

◆ operator()()

void operator() ( )
inline

Definition at line 31 of file nnet-limit-rank.cc.

References AffineComponent::BiasParams(), LimitRankClass::c_, rnnlm::d, Nnet::GetComponent(), LimitRankClass::GetRetainedDim(), KALDI_ASSERT, KALDI_LOG, kaldi::kCopyData, kaldi::kNoTrans, AffineComponent::LinearParams(), MatrixBase< Real >::MulRowsVec(), LimitRankClass::nnet_, MatrixBase< Real >::NumRows(), Vector< Real >::Resize(), Matrix< Real >::Resize(), AffineComponent::SetParams(), kaldi::SortSvd(), and VectorBase< Real >::Sum().

31  {
32  AffineComponent *ac = dynamic_cast<AffineComponent*>(
33  &(nnet_->GetComponent(c_)));
34  KALDI_ASSERT(ac != NULL);
35 
36  // We'll limit the rank of just the linear part, keeping the bias vector full.
37  Matrix<BaseFloat> M (ac->LinearParams());
38  int32 rows = M.NumRows(), cols = M.NumCols(), rc_min = std::min(rows, cols);
39  Vector<BaseFloat> s(rc_min);
40  Matrix<BaseFloat> U(rows, rc_min), Vt(rc_min, cols);
41  // Do the destructive svd M = U diag(s) V^T. It actually outputs the transpose of V.
42  M.DestructiveSvd(&s, &U, &Vt);
43  SortSvd(&s, &U, &Vt); // Sort the singular values from largest to smallest.
44 
45  int32 d = GetRetainedDim(rows, cols);
46  BaseFloat old_svd_sum = s.Sum();
47  U.Resize(rows, d, kCopyData);
48  s.Resize(d, kCopyData);
49  Vt.Resize(d, cols, kCopyData);
50  BaseFloat new_svd_sum = s.Sum();
51  KALDI_LOG << "For component " << c_ << " of dimension " << rows
52  << " x " << cols << ", reduced rank from "
53  << rc_min << " to " << d << ", SVD sum reduced from "
54  << old_svd_sum << " to " << new_svd_sum;
55  Vt.MulRowsVec(s); // Vt <-- diag(s) Vt.
56  M.AddMatMat(1.0, U, kNoTrans, Vt, kNoTrans, 0.0); // Reconstruct with reduced
57  // rank.
58  Vector<BaseFloat> bias_params(ac->BiasParams());
59  ac->SetParams(bias_params, M);
60  }
Symbols referenced in this listing:
- const Component & GetComponent(int32 c) const (Definition: nnet-nnet.cc:141)
- kaldi::int32 int32
- float BaseFloat (Definition: kaldi-types.h:29)
- int32 GetRetainedDim(int32 rows, int32 cols)
- #define KALDI_ASSERT(cond) (Definition: kaldi-error.h:185)
- #define KALDI_LOG (Definition: kaldi-error.h:153)
- void SortSvd(VectorBase< Real > *s, MatrixBase< Real > *U, MatrixBase< Real > *Vt, bool sort_on_absolute_value): function to ensure that the SVD is sorted.

Member Data Documentation

◆ c_

int32 c_
private

Definition at line 94 of file nnet-limit-rank.cc.

Referenced by LimitRankClass::operator()().

◆ nnet_

Nnet* nnet_
private

Definition at line 95 of file nnet-limit-rank.cc.

Referenced by LimitRankClass::operator()().

◆ opts_

const NnetLimitRankOpts& opts_
private

Definition at line 93 of file nnet-limit-rank.cc.

Referenced by LimitRankClass::GetRetainedDim().


The documentation for this class was generated from the following file: