nnet3-acc-lda-stats.cc File Reference


Classes

class  NnetLdaStatsAccumulator
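The usage text describes the output as "stats for estimation of LDA and similar types of transform", where the network output plays the role of the features and the supervision supplies the class labels. Below is a minimal, Kaldi-independent sketch of what such statistics can look like. It assumes they are per-class weighted counts, per-class sums of feature vectors, and one class-independent sum of outer products. From these, the total, between-class and within-class covariances follow, with within = total - between. The names `LdaStatsSketch`, `Accumulate` and `Covariances` are hypothetical and are not the API of `NnetLdaStatsAccumulator`.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical, Kaldi-independent sketch of LDA-style statistics.
// ASSUMPTION: the stats are per-class weighted counts, per-class sums of
// feature vectors, and one class-independent sum of outer products.
// Matrices are stored row-major in std::vector<double> for brevity.
class LdaStatsSketch {
 public:
  LdaStatsSketch(int num_classes, int dim)
      : dim_(dim),
        counts_(num_classes, 0.0),
        sums_(num_classes, std::vector<double>(dim, 0.0)),
        second_(static_cast<std::size_t>(dim) * dim, 0.0) {}

  // Adds one feature vector x (e.g. one frame of network output) with
  // class label `cls` and weight `weight` (e.g. a supervision posterior).
  void Accumulate(const std::vector<double> &x, int cls, double weight) {
    counts_[cls] += weight;
    for (int i = 0; i < dim_; ++i) {
      sums_[cls][i] += weight * x[i];
      for (int j = 0; j < dim_; ++j)
        second_[i * dim_ + j] += weight * x[i] * x[j];
    }
  }

  // Fills total, between-class and within-class covariances, each
  // normalized by the total count; within = total - between.
  // Returns false if no stats were accumulated.
  bool Covariances(std::vector<double> *total, std::vector<double> *between,
                   std::vector<double> *within) const {
    double n = 0.0;
    for (double c : counts_) n += c;
    if (n <= 0.0) return false;
    std::vector<double> mean(dim_, 0.0);  // global mean
    for (const auto &s : sums_)
      for (int i = 0; i < dim_; ++i) mean[i] += s[i] / n;
    const std::size_t sq = static_cast<std::size_t>(dim_) * dim_;
    total->assign(sq, 0.0);
    between->assign(sq, 0.0);
    within->assign(sq, 0.0);
    // Total covariance: E[x x^T] - mean mean^T.
    for (int i = 0; i < dim_; ++i)
      for (int j = 0; j < dim_; ++j)
        (*total)[i * dim_ + j] = second_[i * dim_ + j] / n - mean[i] * mean[j];
    // Between-class covariance: sum_k (n_k / n) (mu_k - mean)(mu_k - mean)^T.
    for (std::size_t k = 0; k < counts_.size(); ++k) {
      if (counts_[k] <= 0.0) continue;  // skip unseen classes
      for (int i = 0; i < dim_; ++i)
        for (int j = 0; j < dim_; ++j) {
          double di = sums_[k][i] / counts_[k] - mean[i];
          double dj = sums_[k][j] / counts_[k] - mean[j];
          (*between)[i * dim_ + j] += counts_[k] / n * di * dj;
        }
    }
    for (std::size_t i = 0; i < sq; ++i)
      (*within)[i] = (*total)[i] - (*between)[i];
    return true;
  }

 private:
  int dim_;
  std::vector<double> counts_;             // per-class weighted counts
  std::vector<std::vector<double>> sums_;  // per-class first-order sums
  std::vector<double> second_;             // total second-order sum
};
```

For example, with one-dimensional features, class 0 seen at 1 and 3 and class 1 seen at 5 and 7 (unit weights), the total variance is 5, the between-class variance is 4 and the within-class variance is 1. Keeping only a total second-order sum, rather than one per class, is enough because the within-class scatter can be recovered by subtraction.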
 

Namespaces

 kaldi
 This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for mispronunciation detection tasks, the reference:
 
 kaldi::nnet3
 

Functions

int main (int argc, char *argv[])
 

Function Documentation

main()

int main(int argc, char *argv[])

Definition at line 135 of file nnet3-acc-lda-stats.cc.

References NnetLdaStatsAccumulator::AccStats(), SequentialTableReader< Holder >::Done(), ParseOptions::GetArg(), KALDI_LOG, SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), ParseOptions::PrintUsage(), ParseOptions::Read(), kaldi::ReadKaldiObject(), ParseOptions::Register(), SequentialTableReader< Holder >::Value(), and NnetLdaStatsAccumulator::WriteStats().

{
  try {
    using namespace kaldi;
    using namespace kaldi::nnet3;
    typedef kaldi::int32 int32;
    typedef kaldi::int64 int64;

    const char *usage =
        "Accumulate statistics in the same format as acc-lda (i.e. stats for\n"
        "estimation of LDA and similar types of transform), starting from nnet\n"
        "training examples. This program puts the features through the network,\n"
        "and the network output will be the features; the supervision in the\n"
        "training examples is used for the class labels. Used in obtaining\n"
        "feature transforms that help nnet training work better.\n"
        "\n"
        "Usage: nnet3-acc-lda-stats [options] <raw-nnet-in> <training-examples-in> <lda-stats-out>\n"
        "e.g.:\n"
        "nnet3-acc-lda-stats 0.raw ark:1.egs 1.acc\n"
        "See also: nnet-get-feature-transform\n";

    bool binary_write = true;
    BaseFloat rand_prune = 0.0;

    ParseOptions po(usage);
    po.Register("binary", &binary_write, "Write output in binary mode");
    po.Register("rand-prune", &rand_prune,
                "Randomized pruning threshold for posteriors");

    po.Read(argc, argv);

    if (po.NumArgs() != 3) {
      po.PrintUsage();
      exit(1);
    }

    std::string nnet_rxfilename = po.GetArg(1),
        examples_rspecifier = po.GetArg(2),
        lda_accs_wxfilename = po.GetArg(3);

    Nnet nnet;
    ReadKaldiObject(nnet_rxfilename, &nnet);

    NnetLdaStatsAccumulator accumulator(rand_prune, nnet);

    int64 num_egs = 0;

    SequentialNnetExampleReader example_reader(examples_rspecifier);
    for (; !example_reader.Done(); example_reader.Next(), num_egs++)
      accumulator.AccStats(example_reader.Value());

    KALDI_LOG << "Processed " << num_egs << " examples.";
    // the next command will die if we accumulated no stats.
    accumulator.WriteStats(lda_accs_wxfilename, binary_write);

    return 0;
  } catch(const std::exception &e) {
    std::cerr << e.what() << '\n';
    return -1;
  }
}
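The `--rand-prune` option is documented only as a "Randomized pruning threshold for posteriors" and defaults to 0.0. The sketch below shows one common, unbiased form of randomized pruning. It assumes this is the kind of behaviour the option controls, which should be confirmed against the Kaldi sources. Posteriors whose magnitude is below the threshold are either dropped or rounded up to the threshold, with probabilities chosen so the expected value is unchanged. The function name `RandPruneSketch` is hypothetical.

```cpp
#include <cmath>
#include <random>

// Hypothetical sketch of unbiased randomized pruning of a posterior.
// ASSUMPTION: this mirrors what --rand-prune does; a threshold of 0 disables it.
// If |post| < threshold, return +/-threshold with probability |post|/threshold
// and 0 otherwise, so the expected return value equals `post`.
double RandPruneSketch(double post, double threshold, std::mt19937 &rng) {
  if (threshold <= 0.0 || std::fabs(post) >= threshold) return post;
  std::uniform_real_distribution<double> uniform(0.0, 1.0);  // in [0, 1)
  if (uniform(rng) < std::fabs(post) / threshold)
    return post >= 0.0 ? threshold : -threshold;
  return 0.0;
}
```

Because small posteriors mostly become exactly zero, fewer (frame, class) pairs reach the accumulator, while the statistics stay unbiased on average. This trades a little variance for speed.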