All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
compute-mce-scale.cc
Go to the documentation of this file.
1 // bin/compute-mce-scale.cc
2 
3 // Copyright 2009-2011 Chao Weng
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
21 #include "base/kaldi-common.h"
22 #include "util/common-utils.h"
23 
24 
25 int main(int argc, char *argv[]) {
26  try {
27  using namespace kaldi;
28  typedef kaldi::int32 int32;
29 
30  const char *usage =
31  "compute the scale of MCE, which is used to scale posteriors\n"
32  "Usage: compute-mce-scale [option] num-score-rspecifier "
33  "den-score-rspecifier out-scale-wspecifier\n";
34 
35  ParseOptions po(usage);
36  kaldi::BaseFloat mce_alpha = 1.0, mce_beta = 0.0;
37  po.Register("mce-alpha", &mce_alpha, "alpha parameter for sigmoid");
38  po.Register("mce-beta", &mce_beta, "beta parameter for sigmoid");
39  po.Read(argc, argv);
40 
41  if (po.NumArgs() != 3) {
42  po.PrintUsage();
43  exit(1);
44  }
45 
46  std::string num_score_rspecifier = po.GetArg(1);
47  std::string den_score_rspecifier = po.GetArg(2);
48  std::string scale_wspecifier = po.GetArg(3);
49 
50  kaldi::SequentialBaseFloatReader num_score_reader(num_score_rspecifier);
51  kaldi::RandomAccessBaseFloatReader den_score_reader(den_score_rspecifier);
52  kaldi::BaseFloatWriter scale_writer(scale_wspecifier);
53 
54  int32 num_scaled = 0, num_no_score = 0;
55  double tot_sigmoid = 0.0;
56 
57  for (; !num_score_reader.Done(); num_score_reader.Next()) {
58  std::string key = num_score_reader.Key();
59  kaldi::BaseFloat num_score = num_score_reader.Value();
60  num_score_reader.FreeCurrent();
61  if (!den_score_reader.HasKey(key)) {
62  num_no_score++;
63  } else {
64  // calculate the sigmoid scaling factor for MCE
65  // Note: the derivative is:
66  // \alpha * sigmoid(num - den) * (1 - sigmoid(num - den))
67  // but the make the scale be:
68  // 4 * sigmoid(num - den) * (1 - sigmoid(num - den))
69  // which is just multiplying by 4/alpha; this means
70  // that the maximum value the scale can have is 1, which
71  // means it's more comparable with MMI/MPE.
72  BaseFloat den_score = den_score_reader.Value(key);
73  BaseFloat score_difference = mce_alpha * (num_score - den_score) + mce_beta;
74  BaseFloat sigmoid_difference = 1.0 / (1.0 + Exp(score_difference));
75  // It might be more natural to make the scale
76  //
77  BaseFloat scale = 4.0 * sigmoid_difference * (1 - sigmoid_difference);
78  scale_writer.Write(key, scale);
79  num_scaled++;
80  tot_sigmoid += sigmoid_difference;
81  }
82  }
83  KALDI_LOG << num_scaled << " scales generated; " << num_no_score
84  << " had no num/den scores.";
85  KALDI_LOG << "Overall MCE objective function per utterance is "
86  << (tot_sigmoid/num_scaled) << " over "
87  << num_scaled << " utterance. [Note: should go down]";
88  return (num_scaled != 0 ? 0 : 1);
89  } catch(const std::exception &e) {
90  std::cerr << e.what();
91  return -1;
92  }
93 }
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
double Exp(double x)
Definition: kaldi-math.h:83
void Write(const std::string &key, const T &value) const
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:366
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const T & Value(const std::string &key)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
bool HasKey(const std::string &key)
int main(int argc, char *argv[])
int NumArgs() const
Number of positional parameters (c.f. argc-1).
#define KALDI_LOG
Definition: kaldi-error.h:133
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.