kws-scoring.h
Go to the documentation of this file.
1 // kws/kws-scoring.h
2 
3 // Copyright (c) 2015, Johns Hopkins University (Yenda Trmal<jtrmal@gmail.com>)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 #ifndef KALDI_KWS_KWS_SCORING_H_
20 #define KALDI_KWS_KWS_SCORING_H_
21 
22 #include <vector>
23 #include <list>
24 #include <utility>
25 #include <string>
26 
27 #include "util/common-utils.h"
28 #include "util/stl-utils.h"
29 
30 namespace kaldi {
31 
32 class KwsTerm {
33  public:
35  utt_id_(0),
36  kw_id_(""),
37  start_time_(0),
38  end_time_(0),
39  score_(0)
40  { }
41 
42  // A convenience function to instantiate the object
43  // from the entries from the results files (generated by kws-search)
44  // In longer term, should be replaced by Read/Write functions
45  KwsTerm(const std::string &kw_id, const std::vector<double> &vec) {
46  set_kw_id(kw_id);
47 
48  KALDI_ASSERT(vec.size() == 4);
49 
50  set_utt_id(vec[0]);
51  set_start_time(vec[1]);
52  set_end_time(vec[2]);
53  set_score(vec[3]);
54  }
55 
56  inline bool valid() const {
57  return (kw_id_ != "");
58  }
59 
60  // Attribute accessors/mutators
61  inline int utt_id() const {return utt_id_;}
62  inline void set_utt_id(int utt_id) {utt_id_ = utt_id;}
63  inline std::string kw_id() const {return kw_id_;}
64  inline void set_kw_id(const std::string &kw_id) {kw_id_ = kw_id;}
65  inline int start_time() const {return start_time_;}
67  inline int end_time() const {return end_time_;}
68  inline void set_end_time(int end_time) {end_time_ = end_time;}
69  inline float score() const {return score_;}
70  inline void set_score(float score) {score_ = score;}
71 
72  private:
73  int utt_id_;
74  std::string kw_id_;
75  int start_time_; // in frames
76  int end_time_; // in frames
77  float score_;
78 };
79 
// Not used, yet
// Possible outcomes of a single keyword detection decision.
enum DetectionDecision {
  kKwsFalseAlarm,      // Marked incorrectly as a hit
  kKwsMiss,            // Not marked as hit while it should be
  kKwsCorr,            // Marked correctly as a hit
  kKwsCorrUndetected,  // Not marked as a hit, correctly
  kKwsUnseen           // Instance not seen in the hypotheses list
};
88 
89 
94 };
95 
96 // Container class for the ref-hyp pairs
97 class KwsAlignment {
98  friend class KwsTermsAligner;
99  public:
100  // TODO(jtrmal): implement reading/writing CSV
101  // void ReadCsv();
102  void WriteCsv(std::iostream &os, const float frames_per_sec);
103 
104  typedef std::vector<AlignedTermsPair> AlignedTerms;
105 
106  inline AlignedTerms::const_iterator begin() const {return alignment_.begin();}
107  inline AlignedTerms::const_iterator end() const {return alignment_.end();}
108  inline int size() const {return alignment_.size(); }
109 
110  private:
111  // sequence of touples ref, hyp, score
112  // either (in the sense of exlusive OR) of which can be
113  // empty (i.e .valid() will return false)
114  // if ref.valid() == false, then the hyp term does not have
115  // a matching reference
116  // if hyp.valid() == false, then the ref term does not have
117  // a matching reference
118  // Score is the aligned score, i.e.
119  AlignedTerms alignment_;
120 
121  inline void Add(const AlignedTermsPair &next) {
122  alignment_.push_back(next);
123  }
124 };
125 
127  int max_distance; // Maximum distance (in frames) of the centers of
128  // the ref and and the hyp to be considered as a potential
129  // match during alignment process
130  // Default: 50 frames (usually 0.5 seconds)
131 
132  inline KwsTermsAlignerOptions(): max_distance(50) {}
133  void Register(OptionsItf *opts);
134 };
135 
137  public:
138  void AddRef(const KwsTerm &ref) {
139  refs_[ref.utt_id()][ref.kw_id()].push_back(ref);
140  nof_refs_++;
141  }
142 
143  void AddHyp(const KwsTerm &hyp) {
144  hyps_.push_back(hyp);
145  nof_hyps_++;
146  }
147 
148  inline int nof_hyps() const {return nof_hyps_;}
149  inline int nof_refs() const {return nof_refs_;}
150 
151  explicit KwsTermsAligner(const KwsTermsAlignerOptions &opts);
152 
153  // Retrieve the final ref-hyp alignment
154  KwsAlignment AlignTerms();
155 
156  // Score the quality of a match between ref and hyp
157  virtual float AlignerScore(const KwsTerm &ref, const KwsTerm &hyp);
158 
159  private:
160  typedef std::vector<KwsTerm> TermArray;
161  typedef std::vector<KwsTerm>::iterator TermIterator;
162  typedef unordered_map<int, bool> TermUseMap;
163  unordered_map<int, unordered_map<std::string, TermArray > > refs_;
164  unordered_map<int, unordered_map<std::string, TermUseMap > > used_ref_terms_;
165  std::list<KwsTerm> hyps_;
169 
170  // Finds the best (if there is one) ref instance for the
171  // given hyp term. Returns index >= 0 when found, -1 when
172  // not found
173  int FindBestRefIndex(const KwsTerm &term);
174 
175  // Find the next adept for best match to hyp.
176  TermIterator FindNextRef(const KwsTerm &hyp,
177  const TermIterator &prev,
178  const TermIterator &last);
179  // A quick test if it's even worth to attempt to look
180  // for a ref for the given term -- checks if the combination
181  // of utt_id and kw_id exists in the reference.
182  bool RefExistsMaybe(const KwsTerm &term);
183 
184  // Adds all ref entries which weren't matched to any hyp
185  void FillUnmatchedRefs(KwsAlignment *ali);
186 };
187 
189  // The option names are taken from the Babel KWS15 eval plan
190  // http://www.nist.gov/itl/iad/mig/upload/KWS15-evalplan-v05.pdf
191  float cost_fa; // The cost of an incorrect detection;
192  // defined as 0.1
193 
194  float value_corr; // The value of a correct detection;
195  // defined as 1.0
196 
197  float prior_probability; // The prior probability of a keyword;
198  // defined as 1e-4
199 
200  float score_threshold; // The score threshold for computation of ATWV
201  // defined as 0.5
202 
203  float sweep_step; // Size of the bin during sweeping for
204  // the oracle measures, 0.05 by default
205 
206  float audio_duration; // Total duration of the audio
207  // This has to be set to a correct value
208  // in seconds, unset by default;
209 
211 
212  inline float beta() const {
213  return (cost_fa/value_corr) * (1.0/prior_probability - 1);
214  }
215 
216  void Register(OptionsItf *opts);
217 };
218 
219 class TwvMetricsStats;
220 
221 class TwvMetrics {
222  public:
223  explicit TwvMetrics(const TwvMetricsOptions &opts);
224  ~TwvMetrics();
225 
226  // Feed the alignment -- can be done several times
227  // so that the statistics will be cumulative
228  void AddAlignment(const KwsAlignment &ali);
229 
230  // Forget the stats
231  void Reset();
232 
233  // Actual Term-Weighted Value
234  float Atwv();
235  // Supreme Term-Weighted Value
236  float Stwv();
237 
238  // Get the MTWV, OTWV and the MTWV threshold
239  // Getting these metrics is computationally intensive
240  // and most of the computations are shared between
241  // getting MTWV and OTWV, so we compute them at he same time
242  void GetOracleMeasures(float *final_mtwv,
243  float *final_mtwv_threshold,
244  float *final_otwv);
245 
246  private:
248 
251  float beta_;
252 
254 
255  void AddEvent(const KwsTerm &ref, const KwsTerm &hyp, float ali_score);
256  void RefAndHypSeen(const std::string &kw_id, float score);
257  void OnlyRefSeen(const std::string &kw_id, float score);
258  void OnlyHypSeen(const std::string &kw_id, float score);
259 };
260 
261 } // namespace kaldi
262 #endif // KALDI_KWS_KWS_SCORING_H_
263 
void AddRef(const KwsTerm &ref)
Definition: kws-scoring.h:138
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
unordered_map< int, bool > TermUseMap
Definition: kws-scoring.h:162
KwsTerm(const std::string &kw_id, const std::vector< double > &vec)
Definition: kws-scoring.h:45
float score() const
Definition: kws-scoring.h:69
void AddHyp(const KwsTerm &hyp)
Definition: kws-scoring.h:143
unordered_map< int, unordered_map< std::string, TermUseMap > > used_ref_terms_
Definition: kws-scoring.h:164
AlignedTerms alignment_
Definition: kws-scoring.h:119
void Add(const AlignedTermsPair &next)
Definition: kws-scoring.h:121
#define KALDI_DISALLOW_COPY_AND_ASSIGN(type)
Definition: kaldi-utils.h:121
std::vector< AlignedTermsPair > AlignedTerms
Definition: kws-scoring.h:104
std::string kw_id() const
Definition: kws-scoring.h:63
DetectionDecision
Definition: kws-scoring.h:81
AlignedTerms::const_iterator begin() const
Definition: kws-scoring.h:106
void set_start_time(int start_time)
Definition: kws-scoring.h:66
std::list< KwsTerm > hyps_
Definition: kws-scoring.h:165
void set_end_time(int end_time)
Definition: kws-scoring.h:68
void set_score(float score)
Definition: kws-scoring.h:70
void set_utt_id(int utt_id)
Definition: kws-scoring.h:62
void set_kw_id(const std::string &kw_id)
Definition: kws-scoring.h:64
int end_time() const
Definition: kws-scoring.h:67
bool valid() const
Definition: kws-scoring.h:56
std::vector< KwsTerm >::iterator TermIterator
Definition: kws-scoring.h:161
unordered_map< int, unordered_map< std::string, TermArray > > refs_
Definition: kws-scoring.h:163
int start_time() const
Definition: kws-scoring.h:65
KwsTermsAlignerOptions opts_
Definition: kws-scoring.h:166
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
int size() const
Definition: kws-scoring.h:108
int utt_id() const
Definition: kws-scoring.h:61
std::string kw_id_
Definition: kws-scoring.h:74
TwvMetricsStats * stats_
Definition: kws-scoring.h:253
float atwv_decision_threshold_
Definition: kws-scoring.h:250
AlignedTerms::const_iterator end() const
Definition: kws-scoring.h:107
std::vector< KwsTerm > TermArray
Definition: kws-scoring.h:160