compute-wer-bootci.cc
Go to the documentation of this file.
1 // bin/compute-wer-bootci.cc
2 
3 // Copyright 2009-2011 Microsoft Corporation
4 // 2014 Johns Hopkins University (authors: Jan Trmal, Daniel Povey)
5 // 2015 Brno University of Technology (author: Karel Vesely)
6 // 2016 Nicolas Serrano
7 
8 // See ../../COPYING for clarification regarding multiple authors
9 //
10 // Licensed under the Apache License, Version 2.0 (the "License");
11 // you may not use this file except in compliance with the License.
12 // You may obtain a copy of the License at
13 //
14 // http://www.apache.org/licenses/LICENSE-2.0
15 //
16 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
18 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
19 // MERCHANTABLITY OR NON-INFRINGEMENT.
20 // See the Apache 2 License for the specific language governing permissions and
21 // limitations under the License.
22 
23 #include "base/kaldi-common.h"
24 #include "util/common-utils.h"
25 #include "util/parse-options.h"
26 #include "tree/context-dep.h"
27 #include "util/edit-distance.h"
28 #include "base/kaldi-math.h"
29 
30 namespace kaldi {
31 
32 void GetEditsSingleHyp( const std::string &hyp_rspecifier,
33  const std::string &ref_rspecifier,
34  const std::string &mode,
35  std::vector<std::pair<int32, int32> > & edit_word_per_hyp) {
36 
37  // Both text and integers are loaded as vector of strings,
38  SequentialTokenVectorReader ref_reader(ref_rspecifier);
39  RandomAccessTokenVectorReader hyp_reader(hyp_rspecifier);
40  int32 num_words = 0, word_errs = 0, num_ins = 0, num_del = 0, num_sub = 0;
41 
42  // Main loop, store WER stats per hyp,
43  for (; !ref_reader.Done(); ref_reader.Next()) {
44  std::string key = ref_reader.Key();
45  const std::vector<std::string> &ref_sent = ref_reader.Value();
46  std::vector<std::string> hyp_sent;
47  if (!hyp_reader.HasKey(key)) {
48  if (mode == "strict")
49  KALDI_ERR << "No hypothesis for key " << key << " and strict "
50  "mode specifier.";
51  if (mode == "present") // do not score this one.
52  continue;
53  } else {
54  hyp_sent = hyp_reader.Value(key);
55  }
56  num_words = ref_sent.size();
57  word_errs = LevenshteinEditDistance(ref_sent, hyp_sent,
58  &num_ins, &num_del, &num_sub);
59  edit_word_per_hyp.push_back(std::pair<int32, int32>(word_errs, num_words));
60  }
61 }
62 
63 void GetEditsDualHyp(const std::string &hyp_rspecifier,
64  const std::string &hyp_rspecifier2,
65  const std::string &ref_rspecifier,
66  const std::string &mode,
67  std::vector<std::pair<int32, int32> > & edit_word_per_hyp,
68  std::vector<std::pair<int32, int32> > & edit_word_per_hyp2) {
69 
70  // Both text and integers are loaded as vector of strings,
71  SequentialTokenVectorReader ref_reader(ref_rspecifier);
72  RandomAccessTokenVectorReader hyp_reader(hyp_rspecifier);
73  RandomAccessTokenVectorReader hyp_reader2(hyp_rspecifier2);
74  int32 num_words = 0, word_errs = 0,
75  num_ins = 0, num_del = 0, num_sub = 0;
76 
77  // Main loop, store WER stats per hyp,
78  for (; !ref_reader.Done(); ref_reader.Next()) {
79  std::string key = ref_reader.Key();
80  const std::vector<std::string> &ref_sent = ref_reader.Value();
81  std::vector<std::string> hyp_sent, hyp_sent2;
82  if (mode == "strict" &&
83  (!hyp_reader.HasKey(key) || !hyp_reader2.HasKey(key))) {
84  KALDI_ERR << "No hypothesis for key " << key << " in both transcripts "
85  "comparison is not possible.";
86  } else if (mode == "present" &&
87  (!hyp_reader.HasKey(key) || !hyp_reader2.HasKey(key)))
88  continue;
89 
90  num_words = ref_sent.size();
91 
92  //all mode, if a hypothesis is not present, consider as an error
93  if(hyp_reader.HasKey(key)){
94  hyp_sent = hyp_reader.Value(key);
95  word_errs = LevenshteinEditDistance(ref_sent, hyp_sent,
96  &num_ins, &num_del, &num_sub);
97  }
98  else
99  word_errs = num_words;
100  edit_word_per_hyp.push_back(std::pair<int32, int32>(word_errs, num_words));
101 
102  if(hyp_reader2.HasKey(key)){
103  hyp_sent2 = hyp_reader2.Value(key);
104  word_errs = LevenshteinEditDistance(ref_sent, hyp_sent2,
105  &num_ins, &num_del, &num_sub);
106  }
107  else
108  word_errs = num_words;
109  edit_word_per_hyp2.push_back(std::pair<int32, int32>(word_errs, num_words));
110  }
111 }
112 
114  const std::vector<std::pair<int32, int32> > & edit_word_per_hyp,
115  int32 replications,
116  BaseFloat *mean, BaseFloat *interval) {
117  BaseFloat wer_accum = 0.0, wer_mult_accum = 0.0;
118 
119  for (int32 i = 0; i < replications; ++i) {
120  int32 num_words = 0, word_errs = 0;
121  for (int32 j = 0; j < edit_word_per_hyp.size(); ++j) {
122  int32 random_pos = kaldi::RandInt(0, edit_word_per_hyp.size() - 1);
123  word_errs += edit_word_per_hyp[random_pos].first;
124  num_words += edit_word_per_hyp[random_pos].second;
125  }
126 
127  BaseFloat wer_rep = static_cast<BaseFloat>(word_errs) / num_words;
128  wer_accum += wer_rep;
129  wer_mult_accum += wer_rep*wer_rep;
130  }
131 
132  // Compute mean WER and std WER
133  *mean = wer_accum / replications;
134  *interval = 1.96*sqrt(wer_mult_accum/replications-(*mean)*(*mean));
135 }
136 
138  const std::vector<std::pair<int32, int32> > & edit_word_per_hyp,
139  const std::vector<std::pair<int32, int32> > & edit_word_per_hyp2,
140  int32 replications, BaseFloat *p_improv) {
141  int32 improv_accum = 0.0;
142 
143  for (int32 i = 0; i < replications; ++i) {
144  int32 word_errs = 0;
145  for (int32 j = 0; j < edit_word_per_hyp.size(); ++j) {
146  int32 random_pos = kaldi::RandInt(0, edit_word_per_hyp.size() - 1);
147  word_errs += edit_word_per_hyp[random_pos].first -
148  edit_word_per_hyp2[random_pos].first;
149  }
150  if(word_errs > 0)
151  ++improv_accum;
152  }
153  // Compute mean WER and std WER
154  *p_improv = static_cast<BaseFloat>(improv_accum) / replications;
155 }
156 
157 } //namespace kaldi
158 
159 int main(int argc, char *argv[]) {
160  using namespace kaldi;
161  typedef kaldi::int32 int32;
162 
163  try {
164  const char *usage =
165  "Compute a bootstrapping of WER to extract the 95% confidence interval.\n"
166  "Take a reference and a transcription file, in integer or text format,\n"
167  "and outputs overall WER statistics to standard output along with its\n"
168  "confidence interval using the bootstrap method of Bisani and Ney.\n"
169  "If a second transcription file corresponding to the same reference is\n"
170  "provided, a bootstrap comparison of the two transcription is performed\n"
171  "to estimate the probability of improvement.\n"
172  "\n"
173  "Usage: compute-wer-bootci [options] <ref-rspecifier> <hyp-rspecifier> [<hyp2-rspecifier>]\n"
174  "E.g.: compute-wer-bootci --mode=present ark:data/train/text ark:hyp_text\n"
175  "or compute-wer-bootci ark:data/train/text ark:hyp_text ark:hyp_text2\n"
176  "See also: compute-wer\n";
177 
178  ParseOptions po(usage);
179 
180  std::string mode = "strict";
181  po.Register("mode", &mode,
182  "Scoring mode: \"present\"|\"all\"|\"strict\":\n"
183  " \"present\" means score those we have transcriptions for\n"
184  " \"all\" means treat absent transcriptions as empty\n"
185  " \"strict\" means die if all in ref not also in hyp");
186 
187  int32 replications = 10000;
188  po.Register("replications", &replications,
189  "Number of replications to compute the intervals");
190 
191  po.Read(argc, argv);
192 
193  if (po.NumArgs() < 2 || po.NumArgs() > 3) {
194  po.PrintUsage();
195  exit(1);
196  }
197 
198  std::string ref_rspecifier = po.GetArg(1);
199  std::string hyp_rspecifier = po.GetArg(2);
200  std::string hyp2_rspecifier = (po.NumArgs() == 3?po.GetArg(3):"");
201 
202  if (mode != "strict" && mode != "present" && mode != "all") {
203  KALDI_ERR <<
204  "--mode option invalid: expected \"present\"|\"all\"|\"strict\", got "
205  << mode;
206  }
207 
208  //Get editions per each utterance
209  std::vector<std::pair<int32, int32> > edit_word_per_hyp, edit_word_per_hyp2;
210  if(hyp2_rspecifier.empty())
211  GetEditsSingleHyp(hyp_rspecifier, ref_rspecifier, mode, edit_word_per_hyp);
212  else
213  GetEditsDualHyp(hyp_rspecifier, hyp2_rspecifier, ref_rspecifier, mode,
214  edit_word_per_hyp, edit_word_per_hyp2);
215 
216  //Extract WER for a number of replications of the same size
217  //as the hypothesis extracted
218  BaseFloat mean_wer = 0.0, interval = 0.0,
219  mean_wer2 = 0.0, interval2 = 0.0,
220  p_improv = 0.0;
221 
222  GetBootstrapWERInterval(edit_word_per_hyp, replications,
223  &mean_wer, &interval);
224 
225  if(!hyp2_rspecifier.empty()) {
226  GetBootstrapWERInterval(edit_word_per_hyp2, replications,
227  &mean_wer2, &interval2);
228 
229  GetBootstrapWERTwoSystemComparison(edit_word_per_hyp, edit_word_per_hyp2,
230  replications, &p_improv);
231  }
232 
233  // Print the output,
234  std::cout.precision(2);
235  std::cerr.precision(2);
236  std::cout << "Set1: %WER " << std::fixed << 100*mean_wer <<
237  " 95% Conf Interval [ " << 100*mean_wer-100*interval <<
238  ", " << 100*mean_wer+100*interval << " ]" << '\n';
239 
240  if(!hyp2_rspecifier.empty()) {
241  std::cout << "Set2: %WER " << std::fixed << 100*mean_wer2 <<
242  " 95% Conf Interval [ " << 100*mean_wer2-100*interval2 <<
243  ", " << 100*mean_wer2+100*interval2 << " ]" << '\n';
244 
245  std::cout << "Probability of Set2 improving Set1: " << std::fixed <<
246  100*p_improv << '\n';
247  }
248 
249  return 0;
250  } catch(const std::exception &e) {
251  std::cerr << e.what();
252  return -1;
253  }
254 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void GetEditsSingleHyp(const std::string &hyp_rspecifier, const std::string &ref_rspecifier, const std::string &mode, std::vector< std::pair< int32, int32 > > &edit_word_per_hyp)
int main(int argc, char *argv[])
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
kaldi::int32 int32
void GetEditsDualHyp(const std::string &hyp_rspecifier, const std::string &hyp_rspecifier2, const std::string &ref_rspecifier, const std::string &mode, std::vector< std::pair< int32, int32 > > &edit_word_per_hyp, std::vector< std::pair< int32, int32 > > &edit_word_per_hyp2)
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
int32 LevenshteinEditDistance(const std::vector< T > &a, const std::vector< T > &b)
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const T & Value(const std::string &key)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
void GetBootstrapWERInterval(const std::vector< std::pair< int32, int32 > > &edit_word_per_hyp, int32 replications, BaseFloat *mean, BaseFloat *interval)
bool HasKey(const std::string &key)
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void GetBootstrapWERTwoSystemComparison(const std::vector< std::pair< int32, int32 > > &edit_word_per_hyp, const std::vector< std::pair< int32, int32 > > &edit_word_per_hyp2, int32 replications, BaseFloat *p_improv)
int32 RandInt(int32 min_val, int32 max_val, struct RandomState *state)
Definition: kaldi-math.cc:95