nnet3-acc-lda-stats.cc
Go to the documentation of this file.
1 // nnet3bin/nnet3-acc-lda-stats.cc
2 
3 // Copyright 2015 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
 16 // MERCHANTABILITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include "base/kaldi-common.h"
21 #include "util/common-utils.h"
22 #include "hmm/transition-model.h"
23 #include "nnet3/nnet-nnet.h"
25 #include "nnet3/nnet-optimize.h"
26 #include "transform/lda-estimate.h"
27 
28 
29 namespace kaldi {
30 namespace nnet3 {
31 
33  public:
35  const Nnet &nnet):
36  rand_prune_(rand_prune), nnet_(nnet), compiler_(nnet) { }
37 
38  void AccStats(const NnetExample &eg) {
39  ComputationRequest request;
40  bool need_backprop = false, store_stats = false;
41  GetComputationRequest(nnet_, eg, need_backprop, store_stats, &request);
42  const NnetComputation &computation = *(compiler_.Compile(request));
43  NnetComputeOptions options;
44  if (GetVerboseLevel() >= 3)
45  options.debug = true;
46  NnetComputer computer(options, computation, nnet_, NULL);
47 
48  computer.AcceptInputs(nnet_, eg.io);
49  computer.Run();
50  const CuMatrixBase<BaseFloat> &nnet_output = computer.GetOutput("output");
51  AccStatsFromOutput(eg, nnet_output);
52  }
53 
54  void WriteStats(const std::string &stats_wxfilename, bool binary) {
55  if (lda_stats_.TotCount() == 0) {
56  KALDI_ERR << "Accumulated no stats.";
57  } else {
58  WriteKaldiObject(lda_stats_, stats_wxfilename, binary);
59  KALDI_LOG << "Accumulated stats, soft frame count = "
60  << lda_stats_.TotCount() << ". Wrote to "
61  << stats_wxfilename;
62  }
63  }
64  private:
66  const CuMatrixBase<BaseFloat> &nnet_output) {
67  BaseFloat rand_prune = rand_prune_;
68  const NnetIo *output_supervision = NULL;
69  for (size_t i = 0; i < eg.io.size(); i++)
70  if (eg.io[i].name == "output")
71  output_supervision = &(eg.io[i]);
72  KALDI_ASSERT(output_supervision != NULL && "no output in eg named 'output'");
73  int32 num_rows = output_supervision->features.NumRows(),
74  num_pdfs = output_supervision->features.NumCols();
75  KALDI_ASSERT(num_rows == nnet_output.NumRows());
76  if (lda_stats_.Dim() == 0)
77  lda_stats_.Init(num_pdfs, nnet_output.NumCols());
78  if (output_supervision->features.Type() == kSparseMatrix) {
79  const SparseMatrix<BaseFloat> &smat =
80  output_supervision->features.GetSparseMatrix();
81  for (int32 r = 0; r < num_rows; r++) {
82  // the following, transferring row by row to CPU, would be wasteful
83  // if we actually were using a GPU, but we don't anticipate doing this
84  // in this program.
85  CuSubVector<BaseFloat> cu_row(nnet_output, r);
86  // "row" is actually just a redudant copy, since we're likely on CPU,
87  // but we're about to do an outer product, so this doesn't dominate.
88  Vector<BaseFloat> row(cu_row);
89 
90  const SparseVector<BaseFloat> &post(smat.Row(r));
91  const std::pair<MatrixIndexT, BaseFloat> *post_data = post.Data(),
92  *post_end = post_data + post.NumElements();
93  for (; post_data != post_end; ++post_data) {
94  MatrixIndexT pdf = post_data->first;
95  BaseFloat weight = post_data->second;
96  BaseFloat pruned_weight = RandPrune(weight, rand_prune);
97  if (pruned_weight != 0.0)
98  lda_stats_.Accumulate(row, pdf, pruned_weight);
99  }
100  }
101  } else {
102  Matrix<BaseFloat> output_mat;
103  output_supervision->features.GetMatrix(&output_mat);
104  for (int32 r = 0; r < num_rows; r++) {
105  // the following, transferring row by row to CPU, would be wasteful
106  // if we actually were using a GPU, but we don't anticipate doing this
107  // in this program.
108  CuSubVector<BaseFloat> cu_row(nnet_output, r);
109  // "row" is actually just a redudant copy, since we're likely on CPU,
110  // but we're about to do an outer product, so this doesn't dominate.
111  Vector<BaseFloat> row(cu_row);
112 
113  SubVector<BaseFloat> post(output_mat, r);
114  int32 num_pdfs = post.Dim();
115  for (int32 pdf = 0; pdf < num_pdfs; pdf++) {
116  BaseFloat weight = post(pdf);
117  BaseFloat pruned_weight = RandPrune(weight, rand_prune);
118  if (pruned_weight != 0.0)
119  lda_stats_.Accumulate(row, pdf, pruned_weight);
120  }
121  }
122  }
123  }
124 
126  const Nnet &nnet_;
129 
130 };
131 
132 }
133 }
134 
135 int main(int argc, char *argv[]) {
136  try {
137  using namespace kaldi;
138  using namespace kaldi::nnet3;
139  typedef kaldi::int32 int32;
140  typedef kaldi::int64 int64;
141 
142  const char *usage =
143  "Accumulate statistics in the same format as acc-lda (i.e. stats for\n"
144  "estimation of LDA and similar types of transform), starting from nnet\n"
145  "training examples. This program puts the features through the network,\n"
146  "and the network output will be the features; the supervision in the\n"
147  "training examples is used for the class labels. Used in obtaining\n"
148  "feature transforms that help nnet training work better.\n"
149  "\n"
150  "Usage: nnet3-acc-lda-stats [options] <raw-nnet-in> <training-examples-in> <lda-stats-out>\n"
151  "e.g.:\n"
152  "nnet3-acc-lda-stats 0.raw ark:1.egs 1.acc\n"
153  "See also: nnet-get-feature-transform\n";
154 
155  bool binary_write = true;
156  BaseFloat rand_prune = 0.0;
157 
158  ParseOptions po(usage);
159  po.Register("binary", &binary_write, "Write output in binary mode");
160  po.Register("rand-prune", &rand_prune,
161  "Randomized pruning threshold for posteriors");
162 
163  po.Read(argc, argv);
164 
165  if (po.NumArgs() != 3) {
166  po.PrintUsage();
167  exit(1);
168  }
169 
170  std::string nnet_rxfilename = po.GetArg(1),
171  examples_rspecifier = po.GetArg(2),
172  lda_accs_wxfilename = po.GetArg(3);
173 
174  Nnet nnet;
175  ReadKaldiObject(nnet_rxfilename, &nnet);
176 
177  NnetLdaStatsAccumulator accumulator(rand_prune, nnet);
178 
179  int64 num_egs = 0;
180 
181  SequentialNnetExampleReader example_reader(examples_rspecifier);
182  for (; !example_reader.Done(); example_reader.Next(), num_egs++)
183  accumulator.AccStats(example_reader.Value());
184 
185  KALDI_LOG << "Processed " << num_egs << " examples.";
186  // the next command will die if we accumulated no stats.
187  accumulator.WriteStats(lda_accs_wxfilename, binary_write);
188 
189  return 0;
190  } catch(const std::exception &e) {
191  std::cerr << e.what() << '\n';
192  return -1;
193  }
194 }
NnetExample is the input data and corresponding label (or labels) for one or more frames of input...
Definition: nnet-example.h:111
void Accumulate(const VectorBase< BaseFloat > &data, int32 class_id, BaseFloat weight=1.0)
Accumulates data.
Definition: lda-estimate.cc:45
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
int main(int argc, char *argv[])
Class for computing linear discriminant analysis (LDA) transform.
Definition: lda-estimate.h:57
void GetMatrix(Matrix< BaseFloat > *mat) const
Outputs the contents as a matrix.
int32 Dim() const
Returns the dimensionality of the feature vectors.
Definition: lda-estimate.h:66
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
int32 GetVerboseLevel()
Get verbosity level, usually set via the command-line '--verbose=' switch.
Definition: kaldi-error.h:60
Float RandPrune(Float post, BaseFloat prune_thresh, struct RandomState *state=NULL)
Definition: kaldi-math.h:174
This class enables you to do the compilation and optimization in one call, and also ensures that if t...
kaldi::int32 int32
GeneralMatrix features
The features or labels.
Definition: nnet-example.h:46
void Init(int32 num_classes, int32 dimension)
Allocates memory for accumulators.
Definition: lda-estimate.cc:26
void Register(const std::string &name, bool *ptr, const std::string &doc)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
MatrixIndexT NumCols() const
void AccStatsFromOutput(const NnetExample &eg, const CuMatrixBase< BaseFloat > &nnet_output)
float BaseFloat
Definition: kaldi-types.h:29
int32 MatrixIndexT
Definition: matrix-common.h:98
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void AccStats(const NnetExample &eg)
GeneralMatrixType Type() const
Returns the type of the matrix: kSparseMatrix, kCompressedMatrix or kFullMatrix.
void AcceptInputs(const Nnet &nnet, const std::vector< NnetIo > &io)
This convenience function calls AcceptInput() in turn on all the inputs in the training example...
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
MatrixIndexT NumElements() const
Returns the number of nonzero elements.
Definition: sparse-matrix.h:74
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
std::shared_ptr< const NnetComputation > Compile(const ComputationRequest &request)
Does the compilation and returns a const pointer to the result, which is owned by this class...
int NumArgs() const
Number of positional parameters (c.f. argc-1).
Matrix for CUDA computing.
Definition: matrix-common.h:69
double TotCount()
Return total count of the data.
Definition: lda-estimate.h:72
MatrixIndexT NumCols() const
Definition: cu-matrix.h:216
A class representing a vector.
Definition: kaldi-vector.h:406
class NnetComputer is responsible for executing the computation described in the "computation" object...
Definition: nnet-compute.h:59
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
void WriteStats(const std::string &stats_wxfilename, bool binary)
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Definition: kaldi-io.h:257
NnetLdaStatsAccumulator(BaseFloat rand_prune, const Nnet &nnet)
const SparseMatrix< BaseFloat > & GetSparseMatrix() const
Returns the contents as a SparseMatrix.
MatrixIndexT NumRows() const
Dimensions.
Definition: cu-matrix.h:215
std::vector< NnetIo > io
"io" contains the input and output.
Definition: nnet-example.h:116
#define KALDI_LOG
Definition: kaldi-error.h:153
std::pair< MatrixIndexT, Real > * Data()
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:501
void GetComputationRequest(const Nnet &nnet, const NnetExample &eg, bool need_model_derivative, bool store_component_stats, ComputationRequest *request)
This function takes a NnetExample (which should already have been frame-selected, if desired...
const SparseVector< Real > & Row(MatrixIndexT r) const