NnetLdaStatsAccumulator Class Reference
Collaboration diagram for NnetLdaStatsAccumulator:

Public Member Functions

 NnetLdaStatsAccumulator (BaseFloat rand_prune, const Nnet &nnet)
 
void AccStats (const NnetExample &eg)
 
void WriteStats (const std::string &stats_wxfilename, bool binary)
 

Private Member Functions

void AccStatsFromOutput (const NnetExample &eg, const CuMatrixBase< BaseFloat > &nnet_output)
 

Private Attributes

BaseFloat rand_prune_
 
const Nnet& nnet_
 
CachingOptimizingCompiler compiler_
 
LdaEstimate lda_stats_
 

Detailed Description

Definition at line 32 of file nnet3-acc-lda-stats.cc.

Constructor & Destructor Documentation

◆ NnetLdaStatsAccumulator()

NnetLdaStatsAccumulator ( BaseFloat  rand_prune,
const Nnet &  nnet 
)
inline

Definition at line 34 of file nnet3-acc-lda-stats.cc.

Member Function Documentation

◆ AccStats()

void AccStats ( const NnetExample &  eg)
inline

Definition at line 38 of file nnet3-acc-lda-stats.cc.

References NnetComputer::AcceptInputs(), NnetLdaStatsAccumulator::AccStatsFromOutput(), CachingOptimizingCompiler::Compile(), NnetLdaStatsAccumulator::compiler_, NnetComputeOptions::debug, kaldi::nnet3::GetComputationRequest(), kaldi::GetVerboseLevel(), NnetExample::io, and NnetLdaStatsAccumulator::nnet_.

Referenced by main().

38  {
39  ComputationRequest request;
40  bool need_backprop = false, store_stats = false;
41  GetComputationRequest(nnet_, eg, need_backprop, store_stats, &request);
42  const NnetComputation &computation = *(compiler_.Compile(request));
43  NnetComputeOptions options;
44  if (GetVerboseLevel() >= 3)
45  options.debug = true;
46  NnetComputer computer(options, computation, nnet_, NULL);
47 
48  computer.AcceptInputs(nnet_, eg.io);
49  computer.Run();
50  const CuMatrixBase<BaseFloat> &nnet_output = computer.GetOutput("output");
51  AccStatsFromOutput(eg, nnet_output);
52  }
int32 GetVerboseLevel()
Get verbosity level, usually set via command line '--verbose=' switch.
Definition: kaldi-error.h:60
void AccStatsFromOutput(const NnetExample &eg, const CuMatrixBase< BaseFloat > &nnet_output)
void AcceptInputs(const Nnet &nnet, const std::vector< NnetIo > &io)
This convenience function calls AcceptInput() in turn on all the inputs in the training example...
std::shared_ptr< const NnetComputation > Compile(const ComputationRequest &request)
Does the compilation and returns a const pointer to the result, which is owned by this class...
Matrix for CUDA computing.
Definition: matrix-common.h:69
class NnetComputer is responsible for executing the computation described in the "computation" object...
Definition: nnet-compute.h:59
std::vector< NnetIo > io
"io" contains the input and output.
Definition: nnet-example.h:116
void GetComputationRequest(const Nnet &nnet, const NnetExample &eg, bool need_model_derivative, bool store_component_stats, ComputationRequest *request)
This function takes a NnetExample (which should already have been frame-selected, if desired...

◆ AccStatsFromOutput()

void AccStatsFromOutput ( const NnetExample &  eg,
const CuMatrixBase< BaseFloat > &  nnet_output 
)
inlineprivate

Definition at line 65 of file nnet3-acc-lda-stats.cc.

References LdaEstimate::Accumulate(), SparseVector< Real >::Data(), VectorBase< Real >::Dim(), LdaEstimate::Dim(), NnetIo::features, GeneralMatrix::GetMatrix(), GeneralMatrix::GetSparseMatrix(), rnnlm::i, LdaEstimate::Init(), NnetExample::io, KALDI_ASSERT, kaldi::kSparseMatrix, NnetLdaStatsAccumulator::lda_stats_, CuMatrixBase< Real >::NumCols(), GeneralMatrix::NumCols(), SparseVector< Real >::NumElements(), CuMatrixBase< Real >::NumRows(), GeneralMatrix::NumRows(), NnetLdaStatsAccumulator::rand_prune_, kaldi::RandPrune(), SparseMatrix< Real >::Row(), and GeneralMatrix::Type().

Referenced by NnetLdaStatsAccumulator::AccStats().

66  {
67  BaseFloat rand_prune = rand_prune_;
68  const NnetIo *output_supervision = NULL;
69  for (size_t i = 0; i < eg.io.size(); i++)
70  if (eg.io[i].name == "output")
71  output_supervision = &(eg.io[i]);
72  KALDI_ASSERT(output_supervision != NULL && "no output in eg named 'output'");
73  int32 num_rows = output_supervision->features.NumRows(),
74  num_pdfs = output_supervision->features.NumCols();
75  KALDI_ASSERT(num_rows == nnet_output.NumRows());
76  if (lda_stats_.Dim() == 0)
77  lda_stats_.Init(num_pdfs, nnet_output.NumCols());
78  if (output_supervision->features.Type() == kSparseMatrix) {
79  const SparseMatrix<BaseFloat> &smat =
80  output_supervision->features.GetSparseMatrix();
81  for (int32 r = 0; r < num_rows; r++) {
82  // the following, transferring row by row to CPU, would be wasteful
83  // if we actually were using a GPU, but we don't anticipate doing this
84  // in this program.
85  CuSubVector<BaseFloat> cu_row(nnet_output, r);
86  // "row" is actually just a redundant copy, since we're likely on CPU,
87  // but we're about to do an outer product, so this doesn't dominate.
88  Vector<BaseFloat> row(cu_row);
89 
90  const SparseVector<BaseFloat> &post(smat.Row(r));
91  const std::pair<MatrixIndexT, BaseFloat> *post_data = post.Data(),
92  *post_end = post_data + post.NumElements();
93  for (; post_data != post_end; ++post_data) {
94  MatrixIndexT pdf = post_data->first;
95  BaseFloat weight = post_data->second;
96  BaseFloat pruned_weight = RandPrune(weight, rand_prune);
97  if (pruned_weight != 0.0)
98  lda_stats_.Accumulate(row, pdf, pruned_weight);
99  }
100  }
101  } else {
102  Matrix<BaseFloat> output_mat;
103  output_supervision->features.GetMatrix(&output_mat);
104  for (int32 r = 0; r < num_rows; r++) {
105  // the following, transferring row by row to CPU, would be wasteful
106  // if we actually were using a GPU, but we don't anticipate doing this
107  // in this program.
108  CuSubVector<BaseFloat> cu_row(nnet_output, r);
110  // "row" is actually just a redundant copy, since we're likely on CPU,
110  // but we're about to do an outer product, so this doesn't dominate.
111  Vector<BaseFloat> row(cu_row);
112 
113  SubVector<BaseFloat> post(output_mat, r);
114  int32 num_pdfs = post.Dim();
115  for (int32 pdf = 0; pdf < num_pdfs; pdf++) {
116  BaseFloat weight = post(pdf);
117  BaseFloat pruned_weight = RandPrune(weight, rand_prune);
118  if (pruned_weight != 0.0)
119  lda_stats_.Accumulate(row, pdf, pruned_weight);
120  }
121  }
122  }
123  }
void Accumulate(const VectorBase< BaseFloat > &data, int32 class_id, BaseFloat weight=1.0)
Accumulates data.
Definition: lda-estimate.cc:45
void GetMatrix(Matrix< BaseFloat > *mat) const
Outputs the contents as a matrix.
int32 Dim() const
Returns the dimensionality of the feature vectors.
Definition: lda-estimate.h:66
Float RandPrune(Float post, BaseFloat prune_thresh, struct RandomState *state=NULL)
Definition: kaldi-math.h:174
kaldi::int32 int32
GeneralMatrix features
The features or labels.
Definition: nnet-example.h:46
void Init(int32 num_classes, int32 dimension)
Allocates memory for accumulators.
Definition: lda-estimate.cc:26
MatrixIndexT NumCols() const
float BaseFloat
Definition: kaldi-types.h:29
int32 MatrixIndexT
Definition: matrix-common.h:98
GeneralMatrixType Type() const
Returns the type of the matrix: kSparseMatrix, kCompressedMatrix or kFullMatrix.
MatrixIndexT NumElements() const
Returns the number of nonzero elements.
Definition: sparse-matrix.h:74
MatrixIndexT NumCols() const
Definition: cu-matrix.h:216
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
const SparseMatrix< BaseFloat > & GetSparseMatrix() const
Returns the contents as a SparseMatrix.
MatrixIndexT NumRows() const
Dimensions.
Definition: cu-matrix.h:215
std::vector< NnetIo > io
"io" contains the input and output.
Definition: nnet-example.h:116
std::pair< MatrixIndexT, Real > * Data()
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:501
const SparseVector< Real > & Row(MatrixIndexT r) const

◆ WriteStats()

void WriteStats ( const std::string &  stats_wxfilename,
bool  binary 
)
inline

Definition at line 54 of file nnet3-acc-lda-stats.cc.

References KALDI_ERR, KALDI_LOG, NnetLdaStatsAccumulator::lda_stats_, LdaEstimate::TotCount(), and kaldi::WriteKaldiObject().

Referenced by main().

54  {
55  if (lda_stats_.TotCount() == 0) {
56  KALDI_ERR << "Accumulated no stats.";
57  } else {
58  WriteKaldiObject(lda_stats_, stats_wxfilename, binary);
59  KALDI_LOG << "Accumulated stats, soft frame count = "
60  << lda_stats_.TotCount() << ". Wrote to "
61  << stats_wxfilename;
62  }
63  }
#define KALDI_ERR
Definition: kaldi-error.h:147
double TotCount()
Return total count of the data.
Definition: lda-estimate.h:72
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Definition: kaldi-io.h:257
#define KALDI_LOG
Definition: kaldi-error.h:153

Member Data Documentation

◆ compiler_

CachingOptimizingCompiler compiler_
private

Definition at line 127 of file nnet3-acc-lda-stats.cc.

Referenced by NnetLdaStatsAccumulator::AccStats().

◆ lda_stats_

◆ nnet_

const Nnet& nnet_
private

Definition at line 126 of file nnet3-acc-lda-stats.cc.

Referenced by NnetLdaStatsAccumulator::AccStats().

◆ rand_prune_

BaseFloat rand_prune_
private

Definition at line 125 of file nnet3-acc-lda-stats.cc.

Referenced by NnetLdaStatsAccumulator::AccStatsFromOutput().


The documentation for this class was generated from the following file: