kws-scoring.cc
Go to the documentation of this file.
1 // kws/kws-scoring.cc
2 
3 // Copyright (c) 2015, Johns Hopkins University (Yenda Trmal<jtrmal@gmail.com>)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include <utility>
21 #include <vector>
22 #include <limits>
23 #include <algorithm>
24 
25 #include "kws/kws-scoring.h"
26 
27 namespace kaldi {
28 namespace kws_internal {
29 
30 class KwTermLower {
31  public:
32  explicit KwTermLower(const int threshold): threshold_(threshold) {}
33 
34  bool operator() (const KwsTerm &left, const KwsTerm &right) {
35  if ( (left.start_time() + threshold_) < right.start_time() ) {
36  return true;
37  } else {
38  return (left.end_time() + threshold_) < right.end_time();
39  }
40  }
41 
42  private:
43  const int threshold_;
44 };
45 
46 class KwTermEqual {
47  public:
48  KwTermEqual(const int max_distance, const KwsTerm &inst):
49  max_distance_(max_distance), inst_(inst) {}
50 
51  bool operator() (const KwsTerm &left, const KwsTerm &right) {
52  bool ret = true;
53 
54  ret &= (left.kw_id() == right.kw_id());
55  ret &= (left.utt_id() == right.utt_id());
56 
57  float center_left = (left.start_time() + left.end_time())/2;
58  float center_right = (right.start_time() + right.end_time())/2;
59 
60 // This was an old definition of the criterion "the hyp is within
61 // max_distance_ area from the ref". The positive thing about the
62 // definition is, that it allows binary search through the collection
63 // ret &= fabs(left.tbeg - right.tbeg) <= max_distance_;
64 // ret &= fabs(left.tend - right.tend) <= max_distance_;
65 
66 // This is the newer definition -- should be equivalent to what F4DE uses
67  ret &= fabs(center_left - center_right) <= max_distance_;
68 
69  return ret;
70  }
71 
72  bool operator() (const KwsTerm &right) {
73  return (*this)(inst_, right);
74  }
75 
76  private:
77  const int max_distance_;
78  const KwsTerm inst_;
79 };
80 
81 
82 struct KwScoreStats {
89 
90  KwScoreStats(): nof_corr(0),
91  nof_fa(0),
92  nof_misses(0),
93  nof_corr_ndet(0),
94  nof_unseen(0),
95  nof_targets(0) {}
96 };
97 
98 struct ThrSweepStats {
101 
102  ThrSweepStats(): nof_corr(0),
103  nof_fa(0) {}
104 };
105 
106 typedef unordered_map <float, ThrSweepStats> SweepThresholdStats;
107 typedef unordered_map <std::string, KwScoreStats> KwStats;
108 typedef unordered_map <std::string, SweepThresholdStats> PerKwSweepStats;
109 
110 } // namespace kws_internal
111 
112 
114  opts->Register("max_distance", &max_distance,
115  "Max distance on the ref and hyp centers "
116  "to be considered as a potential match");
117 }
118 
120  opts_(opts),
121  nof_refs_(0),
122  nof_hyps_(0) { }
123 
124 
126  KwsAlignment alignment;
127 
128  used_ref_terms_.clear();
129  std::list<KwsTerm>::iterator it = hyps_.begin();
130  for (; it != hyps_.end(); ++it) {
131  AlignedTermsPair ref_hyp_pair;
132  ref_hyp_pair.hyp = *it;
133  ref_hyp_pair.aligner_score = -std::numeric_limits<float>::infinity();
134 
135  int ref_idx = FindBestRefIndex(*it);
136  if (ref_idx >= 0) { // If found
137  int utt_id = it->utt_id();
138  std::string kw_id = it->kw_id();
139 
140  ref_hyp_pair.ref = refs_[utt_id][kw_id][ref_idx];
141  used_ref_terms_[utt_id][kw_id][ref_idx] = true;
142  ref_hyp_pair.aligner_score = AlignerScore(ref_hyp_pair.ref,
143  ref_hyp_pair.hyp);
144  }
145 
146  alignment.Add(ref_hyp_pair);
147  }
148  KALDI_LOG << "Alignment size before adding unseen: " << alignment.size();
149  // Finally, find the terms in ref which have not been seen in hyp
150  // and add them into the alignment
151  FillUnmatchedRefs(&alignment);
152  KALDI_LOG << "Alignment size after adding unseen: " << alignment.size();
153  return alignment;
154 }
155 
157  // We have to traverse the whole ref_ structure and check
158  // against the used_ref_terms_ structure if the given ref term
159  // was already used or not. If not, we will add it to the alignment
160  typedef unordered_map<std::string, TermArray> KwList;
161  typedef KwList::iterator KwIndex;
162  typedef unordered_map<int, KwList >::iterator UttIndex;
163 
164  for (UttIndex utt = refs_.begin(); utt != refs_.end(); ++utt) {
165  int utt_id = utt->first;
166  for (KwIndex kw = refs_[utt_id].begin(); kw != refs_[utt_id].end(); ++kw) {
167  std::string kw_id = kw->first;
168  for (TermIterator term = refs_[utt_id][kw_id].begin();
169  term != refs_[utt_id][kw_id].end(); ++term ) {
170  int idx = term - refs_[utt_id][kw_id].begin();
171  if (!used_ref_terms_[utt_id][kw_id][idx]) {
172  AlignedTermsPair missed_hyp;
173  missed_hyp.aligner_score = -std::numeric_limits<float>::infinity();
174  missed_hyp.ref = refs_[utt_id][kw_id][idx];
175  ali->Add(missed_hyp);
176  }
177  }
178  }
179  }
180 }
181 
183  if (!RefExistsMaybe(term)) {
184  return -1;
185  }
186  int utt_id = term.utt_id();
187  std::string kw_id = term.kw_id();
188 
189  TermIterator start_mark = refs_[utt_id][kw_id].begin();
190  TermIterator end_mark = refs_[utt_id][kw_id].end();
191 
192  TermIterator it = FindNextRef(term, start_mark, end_mark);
193  if (it == end_mark) {
194  return -1;
195  }
196 
197  int best_ref_idx = -1;
198  float best_ref_score = -std::numeric_limits<float>::infinity();
199  do {
200  float current_score = AlignerScore(*it, term);
201  int current_index = it - start_mark;
202  if ((current_score > best_ref_score) &&
203  (!used_ref_terms_[utt_id][kw_id][current_index])) {
204  best_ref_idx = current_index;
205  best_ref_score = current_score;
206  }
207 
208  it = FindNextRef(term, ++it, end_mark);
209  } while (it != end_mark);
210 
211  return best_ref_idx;
212 }
213 
214 
216  int utt_id = term.utt_id();
217  std::string kw_id = term.kw_id();
218  if (refs_.count(utt_id) != 0) {
219  if (refs_[utt_id].count(kw_id) != 0) {
220  return true;
221  }
222  }
223  return false;
224 }
225 
226 
227 
229  const KwsTerm &ref,
230  const TermIterator &prev,
231  const TermIterator &last) {
232  return std::find_if(prev, last,
234 }
235 
236 float KwsTermsAligner::AlignerScore(const KwsTerm &ref, const KwsTerm &hyp) {
237  float overlap = std::min(ref.end_time(), hyp.end_time())
238  - std::max(ref.start_time(), hyp.start_time());
239  float join = std::max(ref.end_time(), hyp.end_time())
240  - std::min(ref.start_time(), hyp.start_time());
241  return static_cast<float>(overlap) / join;
242 }
243 
244 void KwsAlignment::WriteCsv(std::iostream &os, const float frames_per_sec) {
245  AlignedTerms::const_iterator it = begin();
246  os << "language,file,channel,termid,term,ref_bt,ref_et,"
247  << "sys_bt,sys_et,sys_score,sys_decision,alignment\n";
248 
249  while ( it != end() ) {
250  int file = it->ref.valid() ? it->ref.utt_id() : it->hyp.utt_id();
251  std::string termid = it->ref.valid() ? it->ref.kw_id() : it->hyp.kw_id();
252  std::string term = termid;
253  std::string lang = "";
254  int channel = 1;
255 
256  os << lang << ","
257  << file << ","
258  << channel << ","
259  << termid << ","
260  << term << ",";
261 
262  if (it->ref.valid()) {
263  os << it->ref.start_time() / static_cast<float>(frames_per_sec) << ","
264  << it->ref.end_time() / static_cast<float>(frames_per_sec) << ",";
265  } else {
266  os << "," << ",";
267  }
268  if (it->hyp.valid()) {
269  os << it->hyp.start_time() / static_cast<float>(frames_per_sec) << ","
270  << it->hyp.end_time() / static_cast<float>(frames_per_sec) << ","
271  << it->hyp.score() << ","
272  << (it->hyp.score() >= 0.5 ? "YES" : "NO") << ",";
273  } else {
274  os << "," << "," << "," << ",";
275  }
276 
277  if (it->ref.valid() && it->hyp.valid()) {
278  os << (it->hyp.score() >= 0.5 ? "CORR" : "MISS");
279  } else if (it->ref.valid()) {
280  os << "MISS";
281  } else if (it->hyp.valid()) {
282  os << (it->hyp.score() >= 0.5 ? "FA" : "CORR!DET");
283  }
284  os << std::endl;
285  it++;
286  }
287 }
288 
289 
291  value_corr(1.0f),
292  prior_probability(1e-4f),
293  score_threshold(0.5f),
294  sweep_step(0.05f),
295  audio_duration(0.0f) {}
296 
298  opts->Register("cost-fa", &cost_fa,
299  "The cost of an incorrect detection");
300  opts->Register("value-corr", &value_corr,
301  "The value (gain) of a correct detection");
302  opts->Register("prior-kw-probability", &prior_probability,
303  "The prior probability of a keyword");
304  opts->Register("score-threshold", &score_threshold,
305  "The score threshold for computation of ATWV");
306  opts->Register("sweep-step", &sweep_step,
307  "Size of the bin during sweeping for the oracle measures");
308 
309  // We won't set the audio duration here, as it's supposed to be
310  // a mandatory argument, not optional
311 }
312 
314  public:
318  std::list<float> sweep_threshold_values;
319 };
320 
322  audio_duration_(opts.audio_duration),
323  atwv_decision_threshold_(opts.score_threshold),
324  beta_(opts.beta()) {
325  stats_ = new TwvMetricsStats();
326  if (opts.sweep_step > 0.0) {
327  for (float i=0.0; i <= 1; i+=opts.sweep_step) {
328  stats_->sweep_threshold_values.push_back(i);
329  }
330  }
331 }
332 
334  delete stats_;
335 }
336 
338  const KwsTerm &hyp,
339  float ali_score) {
340  if (ref.valid() && hyp.valid()) {
341  RefAndHypSeen(hyp.kw_id(), hyp.score());
342  } else if (hyp.valid()) {
343  OnlyHypSeen(hyp.kw_id(), hyp.score());
344  } else if (ref.valid()) {
345  OnlyRefSeen(ref.kw_id(), ref.score());
346  } else {
347  KALDI_ASSERT(ref.valid() || hyp.valid());
348  }
349 }
350 
351 void TwvMetrics::RefAndHypSeen(const std::string &kw_id, float score) {
352  std::list<float>::iterator i = stats_->sweep_threshold_values.begin();
353  for (; i != stats_->sweep_threshold_values.end(); ++i) {
354  float decision_threshold = *i;
355  if ( score >= decision_threshold )
356  stats_->otwv_sweep_cache[kw_id][decision_threshold].nof_corr++;
357  }
358  if (score >= atwv_decision_threshold_) {
360  stats_->keyword_stats[kw_id].nof_corr++;
361  } else {
363  stats_->keyword_stats[kw_id].nof_misses++;
364  }
366  stats_->keyword_stats[kw_id].nof_targets++;
367 }
368 
369 void TwvMetrics::OnlyHypSeen(const std::string &kw_id, float score) {
370  std::list<float>::iterator i = stats_->sweep_threshold_values.begin();
371  for (; i != stats_->sweep_threshold_values.end(); ++i) {
372  float decision_threshold = *i;
373  if ( score >= decision_threshold )
374  stats_->otwv_sweep_cache[kw_id][decision_threshold].nof_fa++;
375  }
376  if (score >= atwv_decision_threshold_) {
378  stats_->keyword_stats[kw_id].nof_fa++;
379  } else {
381  stats_->keyword_stats[kw_id].nof_corr_ndet++;
382  }
383 }
384 
385 void TwvMetrics::OnlyRefSeen(const std::string &kw_id, float score) {
387  stats_->keyword_stats[kw_id].nof_targets++;
389  stats_->keyword_stats[kw_id].nof_unseen++;
390 }
391 
393  KwsAlignment::AlignedTerms::const_iterator it = ali.begin();
394  int k = 0;
395  while (it != ali.end()) {
396  AddEvent(it->ref, it->hyp, it->aligner_score);
397  ++it;
398  ++k;
399  }
400  KALDI_VLOG(4) << "Processed " << k << " alignment entries";
401 }
402 
404  delete stats_;
405  stats_ = new TwvMetricsStats;
406 }
407 
409  typedef kws_internal::KwStats::iterator KwIterator;
410  int32 nof_kw = 0;
411  float atwv = 0;
412 
413  for (KwIterator it = stats_->keyword_stats.begin();
414  it != stats_->keyword_stats.end(); ++it ) {
415  if (it->second.nof_targets == 0) {
416  continue;
417  }
418  float nof_targets = static_cast<float>(it->second.nof_targets);
419  float pmiss = 1 - it->second.nof_corr / nof_targets;
420  float pfa = it->second.nof_fa / (audio_duration_ - nof_targets);
421  float twv = 1 - pmiss - beta_ * pfa;
422 
423  atwv = atwv * (nof_kw)/(nof_kw + 1.0) + twv / (nof_kw + 1.0);
424  nof_kw++;
425  }
426  return atwv;
427 }
428 
430  typedef kws_internal::KwStats::iterator KwIterator;
431  int32 nof_kw = 0;
432  float stwv = 0;
433 
434  for (KwIterator it = stats_->keyword_stats.begin();
435  it != stats_->keyword_stats.end(); ++it ) {
436  if (it->second.nof_targets == 0) {
437  continue;
438  }
439  float nof_targets = static_cast<float>(it->second.nof_targets);
440  float recall = 1 - it->second.nof_unseen / nof_targets;
441 
442  stwv = stwv * (nof_kw)/(nof_kw + 1.0) + recall / (nof_kw + 1.0);
443  nof_kw++;
444  }
445  return stwv;
446 }
447 
448 void TwvMetrics::GetOracleMeasures(float *final_mtwv,
449  float *final_mtwv_threshold,
450  float *final_otwv) {
451  typedef kws_internal::KwStats::iterator KwIterator;
452 
453  int32 nof_kw = 0;
454  float otwv = 0;
455 
456  unordered_map<float, double> mtwv_sweep;
457  for (KwIterator it = stats_->keyword_stats.begin();
458  it != stats_->keyword_stats.end(); ++it ) {
459  if (it->second.nof_targets == 0) {
460  continue;
461  }
462  std::string kw_id = it->first;
463 
464  float local_otwv = -9999;
465  float local_otwv_threshold = -1.0;
466  std::list<float>::iterator i = stats_->sweep_threshold_values.begin();
467  for (; i != stats_->sweep_threshold_values.end(); ++i) {
468  float decision_threshold = *i;
469 
470  float nof_targets = static_cast<float>(it->second.nof_targets);
471  float nof_true = stats_->otwv_sweep_cache[kw_id][decision_threshold].nof_corr;
472  float nof_fa = stats_->otwv_sweep_cache[kw_id][decision_threshold].nof_fa;
473  float pmiss = 1 - nof_true / nof_targets;
474  float pfa = nof_fa / (audio_duration_ - nof_targets);
475  float twv = 1 - pmiss - beta_ * pfa;
476 
477  if (twv > local_otwv) {
478  local_otwv = twv;
479  local_otwv_threshold = decision_threshold;
480  }
481  mtwv_sweep[decision_threshold] = twv / (nof_kw + 1.0) +
482  mtwv_sweep[decision_threshold] * (nof_kw)/(nof_kw + 1.0);
483  }
484  KALDI_ASSERT(local_otwv_threshold >= 0);
485  otwv = otwv * (nof_kw)/(nof_kw + 1.0) + local_otwv / (nof_kw + 1.0);
486  nof_kw++;
487  }
488 
489  float mtwv = -9999;
490  float mtwv_threshold = -1;
491  std::list<float>::iterator i = stats_->sweep_threshold_values.begin();
492  for (; i != stats_->sweep_threshold_values.end(); ++i) {
493  float decision_threshold = *i;
494 
495  if (mtwv_sweep[decision_threshold] > mtwv) {
496  mtwv = mtwv_sweep[decision_threshold];
497  mtwv_threshold = decision_threshold;
498  }
499  }
500  KALDI_ASSERT(mtwv_threshold >= 0);
501  *final_mtwv = mtwv;
502  *final_mtwv_threshold = mtwv_threshold;
503  *final_otwv = otwv;
504 }
505 } // namespace kaldi
506 
507 
KwTermEqual(const int max_distance, const KwsTerm &inst)
Definition: kws-scoring.cc:48
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
bool operator()(const KwsTerm &left, const KwsTerm &right)
Definition: kws-scoring.cc:34
KwsTermsAligner(const KwsTermsAlignerOptions &opts)
Definition: kws-scoring.cc:119
bool RefExistsMaybe(const KwsTerm &term)
Definition: kws-scoring.cc:215
virtual float AlignerScore(const KwsTerm &ref, const KwsTerm &hyp)
Definition: kws-scoring.cc:236
kws_internal::KwScoreStats global_keyword_stats
Definition: kws-scoring.cc:315
float score() const
Definition: kws-scoring.h:69
unordered_map< int, unordered_map< std::string, TermUseMap > > used_ref_terms_
Definition: kws-scoring.h:164
kws_internal::KwStats keyword_stats
Definition: kws-scoring.cc:316
void Add(const AlignedTermsPair &next)
Definition: kws-scoring.h:121
kaldi::int32 int32
void OnlyHypSeen(const std::string &kw_id, float score)
Definition: kws-scoring.cc:369
std::string kw_id() const
Definition: kws-scoring.h:63
virtual void Register(const std::string &name, bool *ptr, const std::string &doc)=0
AlignedTerms::const_iterator begin() const
Definition: kws-scoring.h:106
void RefAndHypSeen(const std::string &kw_id, float score)
Definition: kws-scoring.cc:351
void FillUnmatchedRefs(KwsAlignment *ali)
Definition: kws-scoring.cc:156
void AddEvent(const KwsTerm &ref, const KwsTerm &hyp, float ali_score)
Definition: kws-scoring.cc:337
kws_internal::PerKwSweepStats otwv_sweep_cache
Definition: kws-scoring.cc:317
const size_t count
unordered_map< float, ThrSweepStats > SweepThresholdStats
Definition: kws-scoring.cc:106
std::list< KwsTerm > hyps_
Definition: kws-scoring.h:165
void OnlyRefSeen(const std::string &kw_id, float score)
Definition: kws-scoring.cc:385
KwTermLower(const int threshold)
Definition: kws-scoring.cc:32
int end_time() const
Definition: kws-scoring.h:67
bool valid() const
Definition: kws-scoring.h:56
unordered_map< std::string, KwScoreStats > KwStats
Definition: kws-scoring.cc:107
void AddAlignment(const KwsAlignment &ali)
Definition: kws-scoring.cc:392
std::vector< KwsTerm >::iterator TermIterator
Definition: kws-scoring.h:161
unordered_map< std::string, SweepThresholdStats > PerKwSweepStats
Definition: kws-scoring.cc:108
unordered_map< int, unordered_map< std::string, TermArray > > refs_
Definition: kws-scoring.h:163
int start_time() const
Definition: kws-scoring.h:65
void WriteCsv(std::iostream &os, const float frames_per_sec)
Definition: kws-scoring.cc:244
std::list< float > sweep_threshold_values
Definition: kws-scoring.cc:318
KwsTermsAlignerOptions opts_
Definition: kws-scoring.h:166
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void Register(OptionsItf *opts)
Definition: kws-scoring.cc:297
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
int size() const
Definition: kws-scoring.h:108
void GetOracleMeasures(float *final_mtwv, float *final_mtwv_threshold, float *final_otwv)
Definition: kws-scoring.cc:448
int utt_id() const
Definition: kws-scoring.h:61
TwvMetricsStats * stats_
Definition: kws-scoring.h:253
#define KALDI_LOG
Definition: kaldi-error.h:153
TermIterator FindNextRef(const KwsTerm &hyp, const TermIterator &prev, const TermIterator &last)
Definition: kws-scoring.cc:228
float atwv_decision_threshold_
Definition: kws-scoring.h:250
void Register(OptionsItf *opts)
Definition: kws-scoring.cc:113
AlignedTerms::const_iterator end() const
Definition: kws-scoring.h:107
TwvMetrics(const TwvMetricsOptions &opts)
Definition: kws-scoring.cc:321
int FindBestRefIndex(const KwsTerm &term)
Definition: kws-scoring.cc:182
KwsAlignment AlignTerms()
Definition: kws-scoring.cc:125