1 // online2/online-ivector-feature.h
2 
3 // Copyright 2013-2014 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
21 #ifndef KALDI_ONLINE2_ONLINE_IVECTOR_FEATURE_H_
22 #define KALDI_ONLINE2_ONLINE_IVECTOR_FEATURE_H_
23 
24 #include <string>
25 #include <vector>
26 #include <deque>
27 
28 #include "matrix/matrix-lib.h"
29 #include "util/common-utils.h"
30 #include "base/kaldi-error.h"
31 #include "itf/online-feature-itf.h"
32 #include "gmm/diag-gmm.h"
33 #include "feat/online-feature.h"
34 #include "ivector/ivector-extractor.h"
35 #include "decoder/lattice-faster-online-decoder.h"
36 #include "decoder/lattice-incremental-online-decoder.h"
37 
38 namespace kaldi {
41 
45 
46 /// This class includes configuration variables relating to the online iVector
47 /// extraction, but not including configuration for the "base feature",
48 /// i.e. MFCC/PLP/filterbank, which is an input to this feature.
55 struct OnlineIvectorExtractionConfig {
56  std::string lda_mat_rxfilename; // to read the LDA+MLLT matrix
57  std::string global_cmvn_stats_rxfilename; // to read matrix of global CMVN
58  // stats
59  std::string splice_config_rxfilename; // to read OnlineSpliceOptions
60  std::string cmvn_config_rxfilename; // to read in OnlineCmvnOptions
61  bool online_cmvn_iextractor; // flag activating online-cmvn in iextractor
62  // feature pipeline
63  std::string diag_ubm_rxfilename; // reads type DiagGmm.
64  std::string ivector_extractor_rxfilename; // reads type IvectorExtractor
65 
66  // the following four configuration values should in principle match those
67  // given to the script extract_ivectors_online.sh, although none of them are
68  // super-critical.
69  int32 ivector_period; // How frequently we re-estimate iVectors.
70  int32 num_gselect; // maximum number of posteriors to use per frame for
71  // iVector extractor.
72  BaseFloat min_post; // pruning threshold for posteriors for the iVector
73  // extractor.
74  BaseFloat posterior_scale; // Scale on posteriors used for iVector
75  // extraction; can be interpreted as the inverse
76  // of a scale on the log-prior.
77  BaseFloat max_count; // Maximum stats count we allow before we start scaling
78  // down stats (if nonzero)... this prevents us from getting
79  // atypical-looking iVectors for very long utterances.
80  // Interpret this as a number of frames times
81  // posterior_scale, typically 1/10 of a frame count.
82 
83  int32 num_cg_iters; // set to 15. I don't believe this is very important, so it's
84  // not configurable from the command line for now.
85 
86 
87  // If use_most_recent_ivector is true, we always return the most recent
88  // available iVector rather than the one for the current frame. This means
89  // that if audio is coming in faster than we can process it, we will return a
90  // more accurate iVector.
91  bool use_most_recent_ivector;
92 
93  // If true, always read ahead to NumFramesReady() when getting iVector stats.
94  bool greedy_ivector_extractor;
95 
96  // max_remembered_frames is the largest number of frames it will remember
97  // between utterances of the same speaker; this affects the output of
98  // GetAdaptationState(), and has the effect of limiting the number of frames
99  // of both the CMVN stats and the iVector stats. Setting this to a smaller
100  // value means the adaptation is less constrained by previous utterances
101  // (assuming you provided info from a previous utterance of the same speaker
102  // by calling SetAdaptationState()).
103  BaseFloat max_remembered_frames;
104 
105  OnlineIvectorExtractionConfig(): online_cmvn_iextractor(false),
106  ivector_period(10), num_gselect(5),
107  min_post(0.025), posterior_scale(0.1),
108  max_count(0.0), num_cg_iters(15),
109  use_most_recent_ivector(true),
110  greedy_ivector_extractor(false),
111  max_remembered_frames(1000) { }
112 
113  void Register(OptionsItf *opts) {
114  opts->Register("lda-matrix", &lda_mat_rxfilename, "Filename of LDA matrix, "
115  "e.g. final.mat; used for iVector extraction. ");
116  opts->Register("global-cmvn-stats", &global_cmvn_stats_rxfilename,
117  "(Extended) filename for global CMVN stats, used in iVector "
118  "extraction, obtained for example from "
119  "'matrix-sum scp:data/train/cmvn.scp -', only used for "
120  "iVector extraction");
121  opts->Register("cmvn-config", &cmvn_config_rxfilename, "Configuration "
122  "file for online CMVN features (e.g. conf/online_cmvn.conf), "
123  "only used for iVector extraction. Contains options "
124  "as for the program 'apply-cmvn-online'");
125  opts->Register("online-cmvn-iextractor", &online_cmvn_iextractor,
126  "add online-cmvn to feature pipeline of ivector extractor, "
127  "use the cmvn setup from the UBM. Note: the default of "
128  "false is what we historically used; we'd use true if "
129  "we were using CMVN'ed features for the neural net.");
130  opts->Register("splice-config", &splice_config_rxfilename, "Configuration file "
131  "for frame splicing (--left-context and --right-context "
132  "options); used for iVector extraction.");
133  opts->Register("diag-ubm", &diag_ubm_rxfilename, "Filename of diagonal UBM "
134  "used to obtain posteriors for iVector extraction, e.g. "
135  "final.dubm");
136  opts->Register("ivector-extractor", &ivector_extractor_rxfilename,
137  "Filename of iVector extractor, e.g. final.ie");
138  opts->Register("ivector-period", &ivector_period, "Frequency with which "
139  "we extract iVectors for neural network adaptation");
140  opts->Register("num-gselect", &num_gselect, "Number of Gaussians to select "
141  "for iVector extraction");
142  opts->Register("min-post", &min_post, "Threshold for posterior pruning in "
143  "iVector extraction");
144  opts->Register("posterior-scale", &posterior_scale, "Scale for posteriors in "
145  "iVector extraction (may be viewed as inverse of prior scale)");
146  opts->Register("max-count", &max_count, "Maximum data count we allow before "
147  "we start scaling the stats down (if nonzero)... helps to make "
148  "iVectors from long utterances look more typical. Interpret "
149  "as a frame-count times --posterior-scale, typically 1/10 of "
150  "a number of frames. Suggest 100.");
151  opts->Register("use-most-recent-ivector", &use_most_recent_ivector, "If true, "
152  "always use most recent available iVector, rather than the "
153  "one for the designated frame.");
154  opts->Register("greedy-ivector-extractor", &greedy_ivector_extractor, "If "
155  "true, 'read ahead' as many frames as we currently have available "
156  "when extracting the iVector. May improve iVector quality.");
157  opts->Register("max-remembered-frames", &max_remembered_frames, "The maximum "
158  "number of frames of adaptation history that we carry through "
159  "to later utterances of the same speaker (having a finite "
160  "number allows the speaker adaptation state to change over "
161  "time). Interpret as a real frame count, i.e. not a count "
162  "scaled by --posterior-scale.");
163  }
164 };
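The struct above is designed to be registered on a command-line options object. Below is a minimal, hypothetical sketch (not part of this header) of how a program might do that; the program, usage string and main() scaffolding are assumptions for the example.

// Hypothetical usage sketch: registering OnlineIvectorExtractionConfig on a
// ParseOptions object so its options (--lda-matrix, --diag-ubm, ...) become
// command-line flags.
#include "util/parse-options.h"
#include "online2/online-ivector-feature.h"

int main(int argc, char *argv[]) {
  using namespace kaldi;
  ParseOptions po("Hypothetical example of online iVector option parsing");
  OnlineIvectorExtractionConfig ivector_config;
  ivector_config.Register(&po);   // adds --lda-matrix, --ivector-extractor, etc.
  po.Read(argc, argv);
  // ivector_config now reflects whatever options were supplied on the command line.
  return 0;
}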
165 
166 /// This struct contains various things that are needed (as const references)
167 /// by class OnlineIvectorExtractor.
168 struct OnlineIvectorExtractionInfo {
169 
170  Matrix<BaseFloat> lda_mat; // LDA+MLLT matrix.
171  Matrix<double> global_cmvn_stats; // Global CMVN stats.
172 
173  OnlineCmvnOptions cmvn_opts; // Options for online CMN/CMVN computation.
174  bool online_cmvn_iextractor; // flag activating online CMN/CMVN for iextractor input.
175  OnlineSpliceOptions splice_opts; // Options for frame splicing
176  // (--left-context,--right-context)
177 
178  DiagGmm diag_ubm;
179  IvectorExtractor extractor;
180 
181  // the following configuration variables are copied from
182  // OnlineIvectorExtractionConfig, see comments there.
183  int32 ivector_period;
184  int32 num_gselect;
185  BaseFloat min_post;
186  BaseFloat posterior_scale;
187  BaseFloat max_count;
188  int32 num_cg_iters;
189  bool use_most_recent_ivector;
190  bool greedy_ivector_extractor;
191  BaseFloat max_remembered_frames;
192 
193  OnlineIvectorExtractionInfo(const OnlineIvectorExtractionConfig &config);
194 
195  void Init(const OnlineIvectorExtractionConfig &config);
196 
197  int32 ExpectedFeatureDim() const;
198 
199  // This constructor creates a version of this object where everything
200  // is empty or zero.
201  OnlineIvectorExtractionInfo();
202 
203  void Check() const;
204  private:
205  KALDI_DISALLOW_COPY_AND_ASSIGN(OnlineIvectorExtractionInfo);
206 };
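As a rough sketch of how this struct is meant to be filled in, the following constructs the extraction info from a config; the rxfilenames are invented for the example (in practice they normally come from the command line via Register()).

// Hypothetical sketch: building OnlineIvectorExtractionInfo from a config.
OnlineIvectorExtractionConfig config;
config.lda_mat_rxfilename = "exp/extractor/final.mat";
config.global_cmvn_stats_rxfilename = "exp/extractor/global_cmvn.stats";
config.cmvn_config_rxfilename = "conf/online_cmvn.conf";
config.splice_config_rxfilename = "conf/splice.conf";
config.diag_ubm_rxfilename = "exp/extractor/final.dubm";
config.ivector_extractor_rxfilename = "exp/extractor/final.ie";

// Builds the info, reading the matrices and models named above; Init() can be
// used similarly on an already-constructed object.
OnlineIvectorExtractionInfo ivector_info(config);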
207 
208 /// This class stores the adaptation state from the online iVector extractor,
209 /// which can help you to initialize the adaptation state for the next
210 /// utterance of the same speaker in a more informed way.
211 struct OnlineIvectorExtractorAdaptationState {
212  // CMVN state for the features used to get posteriors for iVector extraction;
213  // online CMVN is not used for the features supplied to the neural net,
214  // instead the iVector is used.
215 
216  // Adaptation state for online CMVN (used for getting posteriors for iVector)
217  OnlineCmvnState cmvn_state;
218 
219  /// Stats for online iVector estimation.
220  OnlineIvectorEstimationStats ivector_stats;
221 
222  /// This constructor initializes adaptation-state with no prior speaker history.
223  OnlineIvectorExtractorAdaptationState(const OnlineIvectorExtractionInfo &info):
224      cmvn_state(info.global_cmvn_stats),
225  ivector_stats(info.extractor.IvectorDim(),
226  info.extractor.PriorOffset(),
227  info.max_count) { }
228 
232 
239  void LimitFrames(BaseFloat max_remembered_frames,
240                   BaseFloat posterior_scale);
241 
242  void Write(std::ostream &os, bool binary) const;
243  void Read(std::istream &is, bool binary);
244 };
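To show how this struct is intended to be used, here is a rough sketch of carrying adaptation state across several utterances of one speaker; the 'ivector_info' and 'base_features' objects are assumptions of the example.

// Rough sketch (assumed setup): 'ivector_info' is an OnlineIvectorExtractionInfo,
// and base_features[i] is an OnlineFeatureInterface* (e.g. an OnlineMfcc) for
// the i'th utterance of this speaker.
OnlineIvectorExtractorAdaptationState adaptation_state(ivector_info);

for (size_t i = 0; i < base_features.size(); i++) {
  OnlineIvectorFeature ivector_feature(ivector_info, base_features[i]);
  ivector_feature.SetAdaptationState(adaptation_state);

  // ... feed audio into base_features[i] and consume frames of
  // ivector_feature as part of decoding this utterance ...

  // Carry what was learned (CMVN state + iVector stats) into the next
  // utterance of the same speaker.
  ivector_feature.GetAdaptationState(&adaptation_state);
}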
245 
246 
247 
248 
249 /// OnlineIvectorFeature is an online feature-extraction class that's
250 /// responsible for extracting iVectors.
255 
256 class OnlineIvectorFeature: public OnlineFeatureInterface {
257  public:
264  explicit OnlineIvectorFeature(const OnlineIvectorExtractionInfo &info,
265                                OnlineFeatureInterface *base_feature);
266 
267  // This version of the constructor accepts per-frame weights (relates to
268  // downweighting silence). This is intended for use in offline operation,
269  // i.e. during training. [will implement this when needed.]
270  //explicit OnlineIvectorFeature(const OnlineIvectorExtractionInfo &info,
271  // std::vector<BaseFloat> frame_weights,
272  //OnlineFeatureInterface *base_feature);
273 
274 
275  // Member functions from OnlineFeatureInterface:
276 
278  virtual int32 Dim() const;
279  virtual bool IsLastFrame(int32 frame) const;
280  virtual int32 NumFramesReady() const;
281  virtual BaseFloat FrameShiftInSeconds() const;
282  virtual void GetFrame(int32 frame, VectorBase<BaseFloat> *feat);
283 
287  void SetAdaptationState(
288  const OnlineIvectorExtractorAdaptationState &adaptation_state);
289 
290 
294  void GetAdaptationState(
295  OnlineIvectorExtractorAdaptationState *adaptation_state) const;
296 
297  virtual ~OnlineIvectorFeature();
298 
299  // Some diagnostics (not present in generic interface):
300  // UBM log-like per frame:
301  BaseFloat UbmLogLikePerFrame() const;
302  // Objective improvement per frame from iVector estimation, versus default iVector
303  // value, measured at utterance end.
304  BaseFloat ObjfImprPerFrame() const;
305 
306  // returns number of frames seen (but not counting the posterior-scale).
307  BaseFloat NumFrames() const {
308    return ivector_stats_.NumFrames() / info_.posterior_scale;
309  }
310 
311 
312  // If you are downweighting silence, you can call
313  // OnlineSilenceWeighting::GetDeltaWeights and supply the output to this class
314  // using UpdateFrameWeights(). The reason why this call happens outside this
315  // class, rather than this class pulling in the data weights, relates to
316  // multi-threaded operation and also from not wanting this class to have
317  // excessive dependencies.
318  //
319  // You must either always call this as soon as new data becomes available
320  // (ideally just after calling AcceptWaveform), or never call it for the
321  // lifetime of this object.
322  void UpdateFrameWeights(
323  const std::vector<std::pair<int32, BaseFloat> > &delta_weights);
324 
325  private:
326 
327  // This accumulates i-vector stats for a set of frames, specified as pairs
328  // (t, weight). The weights do not have to be positive. (In the online
329  // silence-weighting that we do, negative weights can occur if we change our
330  // minds about the assignment of a frame as silence vs. non-silence).
331  void UpdateStatsForFrames(
332  const std::vector<std::pair<int32, BaseFloat> > &frame_weights);
333 
334  // Returns a modified version of info_.min_post, which is opts_.min_post if
335  // weight is 1.0 or -1.0, but gets larger if fabs(weight) is small... but no
336  // larger than 0.99. (This is an efficiency thing, to not bother processing
337  // very small counts).
338  BaseFloat GetMinPost(BaseFloat weight) const;
339 
340  // This is the original UpdateStatsUntilFrame that is called when there is
341  // no data-weighting involved.
342  void UpdateStatsUntilFrame(int32 frame);
343 
344  // This is the new UpdateStatsUntilFrame that is called when there is
345  // data-weighting (i.e. when the user has been calling UpdateFrameWeights()).
346  void UpdateStatsUntilFrameWeighted(int32 frame);
347 
348  void PrintDiagnostics() const;
349 
350  const OnlineIvectorExtractionInfo &info_;
351 
352  OnlineFeatureInterface *base_; // The feature this is built on top of
353  // (e.g. MFCC); not owned here
354 
355  OnlineFeatureInterface *lda_; // LDA on top of raw+splice features.
356  OnlineCmvn *cmvn_; // the CMVN that we give to the lda_normalized_.
357  OnlineFeatureInterface *lda_normalized_; // LDA on top of CMVN+splice
358 
359  // the following is the pointers to OnlineFeatureInterface objects that are
360  // owned here and which we need to delete.
361  std::vector<OnlineFeatureInterface*> to_delete_;
362 
363  /// the iVector estimation stats
364  OnlineIvectorEstimationStats ivector_stats_;
365 
366  /// num_frames_stats_ is the number of frames of data we have already
367  /// accumulated from this utterance and put in ivector_stats_.
372  int32 num_frames_stats_;
373 
374  /// delta_weights_ is written to by UpdateFrameWeights, in the case where the
375  /// iVector estimation is silence-weighted using the decoder traceback.
380  std::priority_queue<std::pair<int32, BaseFloat>,
381  std::vector<std::pair<int32, BaseFloat> >,
382  std::greater<std::pair<int32, BaseFloat> > > delta_weights_;
383 
384  /// this is only used for validating that the frame-weighting code is not buggy.
385  std::vector<BaseFloat> current_frame_weight_debug_;
386 
387  /// delta_weights_provided_ is set to true if UpdateFrameWeights was ever
388  /// called; it's used to detect wrong usage of this class.
389  bool delta_weights_provided_;
390  /// The following is also used to detect wrong usage of this class; it's set
391  /// to true if UpdateStatsUntilFrame() was ever called.
392  bool updated_with_no_delta_weights_;
393 
394  /// if delta_weights_ was ever called, this keeps track of the most recent
395  /// frame that ever had a weight.
396  int32 most_recent_frame_with_weight_;
397 
398  /// The following is only needed for diagnostics.
399  double tot_ubm_loglike_;
400 
402  /// Most recently estimated iVector, will have been estimated at the greatest
403  /// time t where t <= num_frames_stats_.
404  Vector<double> current_ivector_;
405 
408  /// if info_.use_most_recent_ivector == false, we need to store the iVector
409  /// we estimated each info_.ivector_period frames.
411  std::vector<Vector<BaseFloat>* > ivectors_history_;
412 
413 };
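A short sketch of how this class typically sits on top of a base feature follows; the MFCC options, sampling rate and waveform chunk are assumptions of the example, not values prescribed by this header.

// Hypothetical single-utterance sketch.  'mfcc_opts' (MfccOptions),
// 'sampling_rate' and 'wave_chunk' are assumed to exist; 'ivector_info' is an
// OnlineIvectorExtractionInfo as above.
OnlineMfcc mfcc(mfcc_opts);                               // base feature
OnlineIvectorFeature ivector_feature(ivector_info, &mfcc);

mfcc.AcceptWaveform(sampling_rate, wave_chunk);           // push audio in
// ... call AcceptWaveform() repeatedly, and InputFinished() at utterance end ...

Vector<BaseFloat> ivector(ivector_feature.Dim());
for (int32 t = 0; t < ivector_feature.NumFramesReady(); t++) {
  ivector_feature.GetFrame(t, &ivector);
  // 'ivector' holds the iVector associated with frame t; with
  // use_most_recent_ivector == true it is simply the latest estimate.
}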
414 
415 
416 struct OnlineSilenceWeightingConfig {
417  std::string silence_phones_str;
418  // The weighting factor that we apply to silence phones in the iVector
419  // extraction. This option is only relevant if the --silence-phones option is
420  // set.
421  BaseFloat silence_weight;
422 
423  // Transition-ids that get repeated at least this many times (if
424  // max_state_duration > 0) are treated as silence.
425  int32 max_state_duration;
426 
427  // This is the scale that we apply to data that we don't yet have a decoder
428  // traceback for, in the online silence
430 
431  bool Active() const {
432  return !silence_phones_str.empty() && silence_weight != 1.0;
433  }
434 
435  OnlineSilenceWeightingConfig():
436      silence_weight(1.0), max_state_duration(-1) { }
437 
438  void Register(OptionsItf *opts) {
439  opts->Register("silence-phones", &silence_phones_str, "(RE weighting in "
440  "iVector estimation for online decoding) List of integer ids of "
441  "silence phones, separated by colons (or commas). Data that "
442  "(according to the traceback of the decoder) corresponds to "
443  "these phones will be downweighted by --silence-weight.");
444  opts->Register("silence-weight", &silence_weight, "(RE weighting in "
445  "iVector estimation for online decoding) Weighting factor for frames "
446  "that the decoder trace-back identifies as silence; only "
447  "relevant if the --silence-phones option is set.");
448  opts->Register("max-state-duration", &max_state_duration, "(RE weighting in "
449  "iVector estimation for online decoding) Maximum allowed "
450  "duration of a single transition-id; runs with durations longer "
451  "than this will be weighted down to the silence-weight.");
452  }
453  // e.g. prefix = "ivector-silence-weighting"
454  void RegisterWithPrefix(std::string prefix, OptionsItf *opts) {
455  ParseOptions po_prefix(prefix, opts);
456  this->Register(&po_prefix);
457  }
458 };
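The RegisterWithPrefix() method above is how a program with several option groups keeps the flag names from clashing. A hypothetical sketch, assuming a program that parses these options from argc/argv:

// Hypothetical sketch: the options appear on the command line with the given
// prefix, e.g. --ivector-silence-weighting.silence-phones=1:2:3.
ParseOptions po("Hypothetical example program");
OnlineSilenceWeightingConfig silence_weighting_config;
silence_weighting_config.RegisterWithPrefix("ivector-silence-weighting", &po);
po.Read(argc, argv);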
459 
460 // This class is responsible for keeping track of the best-path traceback from
461 // the decoder (efficiently) and computing a weighting of the data based on the
462 // classification of frames as silence (or not silence)... also with a duration
463 // limitation, so data from a very long run of the same transition-id will get
464 // weighted down. (this is often associated with misrecognition or silence).
465 class OnlineSilenceWeighting {
466  public:
467  // Note: you would initialize a new copy of this object for each new
468  // utterance.
469  // The frame-subsampling-factor is used for newer nnet3 models, especially
470  // chain models, when the frame-rate of the decoder is different from the
471  // frame-rate of the input features. E.g. you might set it to 3 for such
472  // models.
473 
474  OnlineSilenceWeighting(const TransitionModel &trans_model,
475  const OnlineSilenceWeightingConfig &config,
476  int32 frame_subsampling_factor = 1);
477 
478  bool Active() const { return config_.Active(); }
479 
480  // This should be called before GetDeltaWeights, so this class knows about the
481  // traceback info from the decoder. It records the traceback information from
482  // the decoder using its BestPathEnd() and related functions.
483  // It will be instantiated for FST == fst::Fst<fst::StdArc> and fst::GrammarFst.
484  template <typename FST>
485  void ComputeCurrentTraceback(const LatticeFasterOnlineDecoderTpl<FST> &decoder);
486  template <typename FST>
487  void ComputeCurrentTraceback(const LatticeIncrementalOnlineDecoderTpl<FST> &decoder);
488 
489  // Calling this function gets the changes in weight that require us to modify
490  // the stats... the output format is (frame-index, delta-weight).
491  //
492  // The num_frames_ready argument is the number of frames available at
493  // the input (or equivalently, output) of the online iVector feature in the
494  // feature pipeline from the stream start. It may be more than the currently
495  // available decoder traceback.
496  //
497  // The first_decoder_frame is the offset from the start of the stream in
498  // pipeline frames when decoder was restarted last time. We do not change
499  // weight for the frames earlier than first_decoder_frame. Set it to 0 in
500  // case of compilation error to reproduce the previous behavior or for a
501  // single utterance decoding.
502  //
503  // How many frames of weights it outputs depends on how much "num_frames_ready"
504  // increased since last time we called this function, and whether the decoder
505  // traceback changed. Negative delta_weights might occur if frames previously
506  // classified as non-silence become classified as silence if the decoder's
507  // traceback changes. You must call this function with "num_frames_ready"
508  // arguments that only increase, not decrease, with time. You would provide
509  // this output to class OnlineIvectorFeature by calling its function
510  // UpdateFrameWeights with the output.
511  //
512  // Returned frame-index is in pipeline frames from the pipeline start.
513  void GetDeltaWeights(
514  int32 num_frames_ready, int32 first_decoder_frame,
515  std::vector<std::pair<int32, BaseFloat> > *delta_weights);
516 
517  // A method for backward compatibility, same as above, but for a single
518  // utterance.
519  void GetDeltaWeights(
520      int32 num_frames_ready,
521  std::vector<std::pair<int32, BaseFloat> > *delta_weights) {
522  GetDeltaWeights(num_frames_ready, 0, delta_weights);
523  }
524 
525  private:
526  const TransitionModel &trans_model_;
527  const OnlineSilenceWeightingConfig &config_;
528 
529  int32 frame_subsampling_factor_;
530 
531  unordered_set<int32> silence_phones_;
532 
533  struct FrameInfo {
534  // The only reason we need the token pointer is to know how far back we have to
535  // trace before the traceback is the same as what we previously traced back.
536  void *token;
537  int32 transition_id;
538  // current_weight is the weight we've previously told the iVector
539  // extractor to use for this frame, if any. It may not equal the
540  // weight we "want" it to use (any difference between the two will
541  // be output when the user calls GetDeltaWeights()).
542  BaseFloat current_weight;
543  FrameInfo(): token(NULL), transition_id(-1), current_weight(0.0) {}
544  };
545 
546  // This contains information about any previously computed traceback;
547  // when the traceback changes we use this variable to compare it with the
548  // previous traceback.
549  // It's indexed at the frame-rate of the decoder (may be different
550  // by 'frame_subsampling_factor_' from the frame-rate of the features).
551  std::vector<FrameInfo> frame_info_;
552 
553  // This records how many frames have been output and that currently reflect
554  // the traceback accurately. It is used to avoid GetDeltaWeights() having to
555  // visit each frame as far back as t = 0, each time it is called.
556  // GetDeltaWeights() sets this to the number of frames that it output, and
557  // ComputeCurrentTraceback() then reduces it to however far it traced back.
558  // However, we may have to go further back in time than this in order to
559  // properly honor the "max-state-duration" config. This, if needed, is done
560  // in GetDeltaWeights() before outputting the delta weights.
561  int32 num_frames_output_and_correct_;
562 };
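To make the intended calling pattern concrete, here is a rough sketch of the per-chunk update an online decoding program would perform, loosely modeled on the online2 decoding binaries. The helper function and the surrounding objects ('decoder', 'ivector_feature', 'silence_weighting') are assumptions of the example, not part of this header.

// Hypothetical helper: call this after advancing the decoder on each chunk.
// 'decoder' is assumed to be a LatticeFasterOnlineDecoderTpl over
// fst::Fst<fst::StdArc>, and 'ivector_feature' the OnlineIvectorFeature in the
// feature pipeline.
void UpdateSilenceWeights(
    OnlineSilenceWeighting *silence_weighting,
    const LatticeFasterOnlineDecoderTpl<fst::Fst<fst::StdArc> > &decoder,
    OnlineIvectorFeature *ivector_feature) {
  if (!silence_weighting->Active() || decoder.NumFramesDecoded() == 0)
    return;
  std::vector<std::pair<int32, BaseFloat> > delta_weights;
  silence_weighting->ComputeCurrentTraceback(decoder);
  // first_decoder_frame is 0 here, as for single-utterance decoding.
  silence_weighting->GetDeltaWeights(ivector_feature->NumFramesReady(), 0,
                                     &delta_weights);
  ivector_feature->UpdateFrameWeights(delta_weights);
}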
563 
564 
566 } // namespace kaldi
567 
568 #endif // KALDI_ONLINE2_ONLINE_IVECTOR_FEATURE_H_