28 namespace kws_internal {
49 max_distance_(max_distance), inst_(inst) {}
67 ret &= fabs(center_left - center_right) <= max_distance_;
73 return (*
this)(inst_, right);
107 typedef unordered_map <std::string, KwScoreStats>
KwStats;
114 opts->
Register(
"max_distance", &max_distance,
115 "Max distance on the ref and hyp centers " 116 "to be considered as a potential match");
129 std::list<KwsTerm>::iterator it =
hyps_.begin();
130 for (; it !=
hyps_.end(); ++it) {
132 ref_hyp_pair.
hyp = *it;
133 ref_hyp_pair.
aligner_score = -std::numeric_limits<float>::infinity();
137 int utt_id = it->utt_id();
138 std::string kw_id = it->kw_id();
140 ref_hyp_pair.
ref =
refs_[utt_id][kw_id][ref_idx];
146 alignment.
Add(ref_hyp_pair);
148 KALDI_LOG <<
"Alignment size before adding unseen: " << alignment.
size();
152 KALDI_LOG <<
"Alignment size after adding unseen: " << alignment.
size();
160 typedef unordered_map<std::string, TermArray> KwList;
161 typedef KwList::iterator KwIndex;
162 typedef unordered_map<int, KwList >::iterator UttIndex;
164 for (UttIndex utt =
refs_.begin(); utt !=
refs_.end(); ++utt) {
165 int utt_id = utt->first;
166 for (KwIndex kw =
refs_[utt_id].begin(); kw !=
refs_[utt_id].end(); ++kw) {
167 std::string kw_id = kw->first;
169 term !=
refs_[utt_id][kw_id].end(); ++term ) {
170 int idx = term -
refs_[utt_id][kw_id].begin();
173 missed_hyp.
aligner_score = -std::numeric_limits<float>::infinity();
174 missed_hyp.
ref =
refs_[utt_id][kw_id][idx];
175 ali->
Add(missed_hyp);
186 int utt_id = term.
utt_id();
187 std::string kw_id = term.
kw_id();
193 if (it == end_mark) {
197 int best_ref_idx = -1;
198 float best_ref_score = -std::numeric_limits<float>::infinity();
201 int current_index = it - start_mark;
202 if ((current_score > best_ref_score) &&
204 best_ref_idx = current_index;
205 best_ref_score = current_score;
209 }
while (it != end_mark);
216 int utt_id = term.
utt_id();
217 std::string kw_id = term.
kw_id();
218 if (
refs_.count(utt_id) != 0) {
232 return std::find_if(prev, last,
241 return static_cast<float>(overlap) / join;
245 AlignedTerms::const_iterator it = begin();
246 os <<
"language,file,channel,termid,term,ref_bt,ref_et," 247 <<
"sys_bt,sys_et,sys_score,sys_decision,alignment\n";
249 while ( it != end() ) {
250 int file = it->ref.valid() ? it->ref.utt_id() : it->hyp.utt_id();
251 std::string termid = it->ref.valid() ? it->ref.kw_id() : it->hyp.kw_id();
252 std::string term = termid;
253 std::string lang =
"";
262 if (it->ref.valid()) {
263 os << it->ref.start_time() /
static_cast<float>(frames_per_sec) <<
"," 264 << it->ref.end_time() /
static_cast<float>(frames_per_sec) <<
",";
268 if (it->hyp.valid()) {
269 os << it->hyp.start_time() /
static_cast<float>(frames_per_sec) <<
"," 270 << it->hyp.end_time() /
static_cast<float>(frames_per_sec) <<
"," 271 << it->hyp.score() <<
"," 272 << (it->hyp.score() >= 0.5 ?
"YES" :
"NO") <<
",";
274 os <<
"," <<
"," <<
"," <<
",";
277 if (it->ref.valid() && it->hyp.valid()) {
278 os << (it->hyp.score() >= 0.5 ?
"CORR" :
"MISS");
279 }
else if (it->ref.valid()) {
281 }
else if (it->hyp.valid()) {
282 os << (it->hyp.score() >= 0.5 ?
"FA" :
"CORR!DET");
292 prior_probability(1e-4f),
293 score_threshold(0.5f),
295 audio_duration(0.0f) {}
299 "The cost of an incorrect detection");
301 "The value (gain) of a correct detection");
303 "The prior probability of a keyword");
305 "The score threshold for computation of ATWV");
307 "Size of the bin during sweeping for the oracle measures");
342 }
else if (hyp.
valid()) {
344 }
else if (ref.
valid()) {
354 float decision_threshold = *
i;
355 if ( score >= decision_threshold )
372 float decision_threshold = *
i;
373 if ( score >= decision_threshold )
393 KwsAlignment::AlignedTerms::const_iterator it = ali.
begin();
395 while (it != ali.
end()) {
396 AddEvent(it->ref, it->hyp, it->aligner_score);
400 KALDI_VLOG(4) <<
"Processed " << k <<
" alignment entries";
409 typedef kws_internal::KwStats::iterator KwIterator;
415 if (it->second.nof_targets == 0) {
418 float nof_targets =
static_cast<float>(it->second.nof_targets);
419 float pmiss = 1 - it->second.nof_corr / nof_targets;
421 float twv = 1 - pmiss -
beta_ * pfa;
423 atwv = atwv * (nof_kw)/(nof_kw + 1.0) + twv / (nof_kw + 1.0);
430 typedef kws_internal::KwStats::iterator KwIterator;
436 if (it->second.nof_targets == 0) {
439 float nof_targets =
static_cast<float>(it->second.nof_targets);
440 float recall = 1 - it->second.nof_unseen / nof_targets;
442 stwv = stwv * (nof_kw)/(nof_kw + 1.0) + recall / (nof_kw + 1.0);
449 float *final_mtwv_threshold,
451 typedef kws_internal::KwStats::iterator KwIterator;
456 unordered_map<float, double> mtwv_sweep;
459 if (it->second.nof_targets == 0) {
462 std::string kw_id = it->first;
464 float local_otwv = -9999;
465 float local_otwv_threshold = -1.0;
468 float decision_threshold = *
i;
470 float nof_targets =
static_cast<float>(it->second.nof_targets);
473 float pmiss = 1 - nof_true / nof_targets;
475 float twv = 1 - pmiss -
beta_ * pfa;
477 if (twv > local_otwv) {
479 local_otwv_threshold = decision_threshold;
481 mtwv_sweep[decision_threshold] = twv / (nof_kw + 1.0) +
482 mtwv_sweep[decision_threshold] * (nof_kw)/(nof_kw + 1.0);
485 otwv = otwv * (nof_kw)/(nof_kw + 1.0) + local_otwv / (nof_kw + 1.0);
490 float mtwv_threshold = -1;
493 float decision_threshold = *
i;
495 if (mtwv_sweep[decision_threshold] > mtwv) {
496 mtwv = mtwv_sweep[decision_threshold];
497 mtwv_threshold = decision_threshold;
502 *final_mtwv_threshold = mtwv_threshold;
KwTermEqual(const int max_distance, const KwsTerm &inst)
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
bool operator()(const KwsTerm &left, const KwsTerm &right)
KwsTermsAligner(const KwsTermsAlignerOptions &opts)
bool RefExistsMaybe(const KwsTerm &term)
virtual float AlignerScore(const KwsTerm &ref, const KwsTerm &hyp)
kws_internal::KwScoreStats global_keyword_stats
unordered_map< int, unordered_map< std::string, TermUseMap > > used_ref_terms_
kws_internal::KwStats keyword_stats
void Add(const AlignedTermsPair &next)
void OnlyHypSeen(const std::string &kw_id, float score)
std::string kw_id() const
virtual void Register(const std::string &name, bool *ptr, const std::string &doc)=0
AlignedTerms::const_iterator begin() const
void RefAndHypSeen(const std::string &kw_id, float score)
void FillUnmatchedRefs(KwsAlignment *ali)
void AddEvent(const KwsTerm &ref, const KwsTerm &hyp, float ali_score)
kws_internal::PerKwSweepStats otwv_sweep_cache
unordered_map< float, ThrSweepStats > SweepThresholdStats
std::list< KwsTerm > hyps_
void OnlyRefSeen(const std::string &kw_id, float score)
KwTermLower(const int threshold)
unordered_map< std::string, KwScoreStats > KwStats
void AddAlignment(const KwsAlignment &ali)
std::vector< KwsTerm >::iterator TermIterator
unordered_map< std::string, SweepThresholdStats > PerKwSweepStats
unordered_map< int, unordered_map< std::string, TermArray > > refs_
void WriteCsv(std::iostream &os, const float frames_per_sec)
std::list< float > sweep_threshold_values
KwsTermsAlignerOptions opts_
#define KALDI_ASSERT(cond)
void Register(OptionsItf *opts)
void GetOracleMeasures(float *final_mtwv, float *final_mtwv_threshold, float *final_otwv)
TermIterator FindNextRef(const KwsTerm &hyp, const TermIterator &prev, const TermIterator &last)
float atwv_decision_threshold_
void Register(OptionsItf *opts)
AlignedTerms::const_iterator end() const
TwvMetrics(const TwvMetricsOptions &opts)
int FindBestRefIndex(const KwsTerm &term)
KwsAlignment AlignTerms()