compute-atwv.cc
Go to the documentation of this file.
1 // kwsbin/compute-atwv.cc
2 
3 // Copyright (c) 2015, Johns Hopkins University (Yenda Trmal<jtrmal@gmail.com>)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
21 #include <algorithm>
22 #include <iomanip> // std::setw
23 
24 #include "base/kaldi-common.h"
25 #include "util/common-utils.h"
26 #include "util/stl-utils.h"
27 #include "kws/kws-scoring.h"
28 
29 
30 int main(int argc, char *argv[]) {
31  try {
32  using namespace kaldi;
33 
34  typedef kaldi::int32 int32;
35  typedef kaldi::uint32 uint32;
36  typedef kaldi::uint64 uint64;
37 
38  const char *usage = "Computes the Actual Term-Weighted Value and prints it."
39  "\n"
40  "Usage: \n"
41  " compute-atwv [options] <nof-trials> <ref-rspecifier>"
42  " <hyp-rspecifier> [alignment-csv-filename]\n"
43  "e.g.: \n"
44  " compute-atwv 32485.4 ark:ref.1 ark:hyp.1 ali.csv\n"
45  "or: \n"
46  " compute-atwv 32485.4 ark:ref.1 ark:hyp.1\n"
47  "\n"
48  "NOTES: \n"
49  " a) the number of trials is usually equal to the size of the searched\n"
50  " collection in seconds\n"
51  " b the ref-rspecifier/hyp-rspecifier are the kaldi IO specifiers \n"
52  " for both the reference and the hypotheses (found hits), "
53  " respectively The format is the same for both of them. Each line\n"
54  " is of the following format\n"
55  "\n"
56  " <KW-ID> <utterance-id> <start-frame> <end-frame> <score>\n\n"
57  " e.g.:\n\n"
58  " KW106-189 348 459 560 0.8\n"
59  "\n"
60  " b) the alignment-csv-filename is an optional parameter. \n"
61  " If present, the alignment i.e. detailed information about what \n"
62  " hypotheses match up with which reference entries will be \n"
63  " generated. The alignemnt file format is equivalent to \n"
64  " the alignment file produced using the F4DE tool. However, we do"
65  " not set some fields and the utterance identifiers are numeric.\n"
66  " You can use the script utils/int2sym.pl and the utterance and \n"
67  " keyword maps to convert the numerical ids into text form\n"
68  " c) the scores are expected to be probabilities. Please note that\n"
69  " the output from the kws-search is in -log(probability).\n"
70  " d) compute-atwv does not perform any score normalization (it's just\n"
71  " for scoring purposes). Without score normalization/calibration\n"
72  " the performance of the search will be quite poor.\n";
73 
74  ParseOptions po(usage);
75  KwsTermsAlignerOptions ali_opts;
76  TwvMetricsOptions twv_opts;
77  int frames_per_sec = 100;
78 
79  ali_opts.Register(&po);
80  twv_opts.Register(&po);
81  po.Register("frames-per-sec", &frames_per_sec,
82  "Number of feature vector frames per second. This is used only when"
83  "writing the alignment to a file");
84 
85  po.Read(argc, argv);
86 
87  if (po.NumArgs() < 3 || po.NumArgs() > 4) {
88  po.PrintUsage();
89  exit(1);
90  }
91 
92  if (!kaldi::ConvertStringToReal(po.GetArg(1), &twv_opts.audio_duration)) {
93  KALDI_ERR << "The duration parameter is not a number";
94  }
95  if (twv_opts.audio_duration <= 0) {
96  KALDI_ERR << "The duration is either negative or zero";
97  }
98 
99  KwsTermsAligner aligner(ali_opts);
100  TwvMetrics twv_scores(twv_opts);
101 
102  std::string ref_rspecifier = po.GetArg(2),
103  hyp_rspecifier = po.GetArg(3),
104  ali_output = po.GetOptArg(4);
105 
107  ref_reader(ref_rspecifier);
108 
109  for (; !ref_reader.Done(); ref_reader.Next()) {
110  std::string kwid = ref_reader.Key();
111  std::vector<double> vals = ref_reader.Value();
112  if (vals.size() != 4) {
113  KALDI_ERR << "Incorrect format of the reference file"
114  << " -- 4 entries expected, " << vals.size() << " given!\n"
115  << "Key: " << kwid;
116  }
117  KwsTerm inst(kwid, vals);
118  aligner.AddRef(inst);
119  }
120 
122  hyp_reader(hyp_rspecifier);
123 
124  for (; !hyp_reader.Done(); hyp_reader.Next()) {
125  std::string kwid = hyp_reader.Key();
126  std::vector<double> vals = hyp_reader.Value();
127  if (vals.size() != 4) {
128  KALDI_ERR << "Incorrect format of the hypotheses file"
129  << " -- 4 entries expected, " << vals.size() << " given!\n"
130  << "Key: " << kwid;
131  }
132  KwsTerm inst(kwid, vals);
133  aligner.AddHyp(inst);
134  }
135 
136  KALDI_LOG << "Read " << aligner.nof_hyps() << " hypotheses";
137  KALDI_LOG << "Read " << aligner.nof_refs() << " references";
138  KwsAlignment ali = aligner.AlignTerms();
139 
140  if (ali_output != "") {
141  std::fstream fs;
142  fs.open(ali_output.c_str(), std::fstream::out);
143  ali.WriteCsv(fs, frames_per_sec);
144  fs.close();
145  }
146 
147  TwvMetrics scores(twv_opts);
148  scores.AddAlignment(ali);
149 
150  std::cout << "aproximate ATWV = "
151  << std::fixed << std::setprecision(4)
152  << scores.Atwv() << std::endl;
153  std::cout << "aproximate STWV = "
154  << std::fixed << std::setprecision(4)
155  << scores.Stwv() << std::endl;
156 
157  float mtwv, mtwv_threshold, otwv;
158  scores.GetOracleMeasures(&mtwv, &mtwv_threshold, &otwv);
159 
160  std::cout << "aproximate MTWV = "
161  << std::fixed << std::setprecision(4)
162  << mtwv << std::endl;
163  std::cout << "aproximate MTWV threshold = "
164  << std::fixed << std::setprecision(4)
165  << mtwv_threshold << std::endl;
166  std::cout << "aproximate OTWV = "
167  << std::fixed << std::setprecision(4)
168  << otwv << std::endl;
169  } catch(const std::exception &e) {
170  std::cerr << e.what();
171  return -1;
172  }
173 }
void AddRef(const KwsTerm &ref)
Definition: kws-scoring.h:138
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void AddHyp(const KwsTerm &hyp)
Definition: kws-scoring.h:143
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
kaldi::int32 int32
void Register(const std::string &name, bool *ptr, const std::string &doc)
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
bool ConvertStringToReal(const std::string &str, T *out)
ConvertStringToReal converts a string into either float or double and returns false if there was any ...
Definition: text-utils.cc:238
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
int main(int argc, char *argv[])
Definition: compute-atwv.cc:30
void AddAlignment(const KwsAlignment &ali)
Definition: kws-scoring.cc:392
void WriteCsv(std::iostream &os, const float frames_per_sec)
Definition: kws-scoring.cc:244
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void Register(OptionsItf *opts)
Definition: kws-scoring.cc:297
void GetOracleMeasures(float *final_mtwv, float *final_mtwv_threshold, float *final_otwv)
Definition: kws-scoring.cc:448
#define KALDI_LOG
Definition: kaldi-error.h:153
void Register(OptionsItf *opts)
Definition: kws-scoring.cc:113
std::string GetOptArg(int param) const
KwsAlignment AlignTerms()
Definition: kws-scoring.cc:125