All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
compute-mce-scale.cc File Reference
Include dependency graph for compute-mce-scale.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

int main ( int  argc,
char *  argv[] 
)

Definition at line 25 of file compute-mce-scale.cc.

References SequentialTableReader< Holder >::Done(), kaldi::Exp(), SequentialTableReader< Holder >::FreeCurrent(), ParseOptions::GetArg(), RandomAccessTableReader< Holder >::HasKey(), KALDI_LOG, SequentialTableReader< Holder >::Key(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), ParseOptions::PrintUsage(), ParseOptions::Read(), ParseOptions::Register(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), and TableWriter< Holder >::Write().

25  {
26  try {
27  using namespace kaldi;
28  typedef kaldi::int32 int32;
29 
30  const char *usage =
31  "compute the scale of MCE, which is used to scale posteriors\n"
32  "Usage: compute-mce-scale [option] num-score-rspecifier "
33  "den-score-rspecifier out-scale-wspecifier\n";
34 
35  ParseOptions po(usage);
36  kaldi::BaseFloat mce_alpha = 1.0, mce_beta = 0.0;
37  po.Register("mce-alpha", &mce_alpha, "alpha parameter for sigmoid");
38  po.Register("mce-beta", &mce_beta, "beta parameter for sigmoid");
39  po.Read(argc, argv);
40 
41  if (po.NumArgs() != 3) {
42  po.PrintUsage();
43  exit(1);
44  }
45 
46  std::string num_score_rspecifier = po.GetArg(1);
47  std::string den_score_rspecifier = po.GetArg(2);
48  std::string scale_wspecifier = po.GetArg(3);
49 
50  kaldi::SequentialBaseFloatReader num_score_reader(num_score_rspecifier);
51  kaldi::RandomAccessBaseFloatReader den_score_reader(den_score_rspecifier);
52  kaldi::BaseFloatWriter scale_writer(scale_wspecifier);
53 
54  int32 num_scaled = 0, num_no_score = 0;
55  double tot_sigmoid = 0.0;
56 
57  for (; !num_score_reader.Done(); num_score_reader.Next()) {
58  std::string key = num_score_reader.Key();
59  kaldi::BaseFloat num_score = num_score_reader.Value();
60  num_score_reader.FreeCurrent();
61  if (!den_score_reader.HasKey(key)) {
62  num_no_score++;
63  } else {
64  // calculate the sigmoid scaling factor for MCE
65  // Note: the derivative is:
66  // \alpha * sigmoid(num - den) * (1 - sigmoid(num - den))
67  // but the make the scale be:
68  // 4 * sigmoid(num - den) * (1 - sigmoid(num - den))
69  // which is just multiplying by 4/alpha; this means
70  // that the maximum value the scale can have is 1, which
71  // means it's more comparable with MMI/MPE.
72  BaseFloat den_score = den_score_reader.Value(key);
73  BaseFloat score_difference = mce_alpha * (num_score - den_score) + mce_beta;
74  BaseFloat sigmoid_difference = 1.0 / (1.0 + Exp(score_difference));
75  // It might be more natural to make the scale
76  //
77  BaseFloat scale = 4.0 * sigmoid_difference * (1 - sigmoid_difference);
78  scale_writer.Write(key, scale);
79  num_scaled++;
80  tot_sigmoid += sigmoid_difference;
81  }
82  }
83  KALDI_LOG << num_scaled << " scales generated; " << num_no_score
84  << " had no num/den scores.";
85  KALDI_LOG << "Overall MCE objective function per utterance is "
86  << (tot_sigmoid/num_scaled) << " over "
87  << num_scaled << " utterance. [Note: should go down]";
88  return (num_scaled != 0 ? 0 : 1);
89  } catch(const std::exception &e) {
90  std::cerr << e.what();
91  return -1;
92  }
93 }
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
double Exp(double x)
Definition: kaldi-math.h:83
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:366
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
#define KALDI_LOG
Definition: kaldi-error.h:133