ivector-extract-online2.cc File Reference
Include dependency graph for ivector-extract-online2.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 28 of file ivector-extract-online2.cc.

References MatrixBase< Real >::ColRange(), VectorBase< Real >::Dim(), OnlineIvectorFeature::Dim(), SequentialTableReader< Holder >::Done(), OnlineIvectorExtractionInfo::ExpectedFeatureDim(), OnlineIvectorExtractionInfo::extractor, kaldi::g_num_threads, OnlineIvectorFeature::GetAdaptationState(), ParseOptions::GetArg(), OnlineIvectorFeature::GetFrame(), RandomAccessTableReader< Holder >::HasKey(), rnnlm::i, OnlineIvectorExtractionConfig::ivector_period, IvectorExtractor::IvectorDim(), KALDI_LOG, KALDI_VLOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), rnnlm::n, SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), OnlineIvectorFeature::ObjfImprPerFrame(), ParseOptions::PrintUsage(), ParseOptions::Read(), ParseOptions::Register(), OnlineIvectorExtractionConfig::Register(), OnlineIvectorFeature::SetAdaptationState(), OnlineIvectorFeature::UbmLogLikePerFrame(), OnlineIvectorFeature::UpdateFrameWeights(), OnlineIvectorExtractionConfig::use_most_recent_ivector, RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), and TableWriter< Holder >::Write().

28  {
29  using namespace kaldi;
30  typedef kaldi::int32 int32;
31  typedef kaldi::int64 int64;
32  try {
33  const char *usage =
34  "Extract iVectors for utterances every --ivector-period frames, using a trained\n"
35  "iVector extractor and features and Gaussian-level posteriors. Similar to\n"
36  "ivector-extract-online but uses the actual online decoder code to do it,\n"
37  "and does everything in-memory instead of using multiple processes.\n"
38  "Note: the value of the --use-most-recent-ivector config variable is ignored\n"
39  "it's set to false. The <spk2utt-rspecifier> is mandatory, to simplify the code;\n"
40  "if you want to do it separately per utterance, just make it of the form\n"
41  "<utterance-id> <utterance-id>.\n"
42  "The iVectors are output as an archive of matrices, indexed by utterance-id;\n"
43  "each row corresponds to an iVector. If --repeat=true, outputs the whole matrix\n"
44  "of iVectors, not just every (ivector-period)'th frame\n"
45  "The input features are the raw, non-cepstral-mean-normalized features, e.g. MFCC.\n"
46  "\n"
47  "Usage: ivector-extract-online2 [options] <spk2utt-rspecifier> <feature-rspecifier> <ivector-wspecifier>\n"
48  "e.g.: \n"
49  " ivector-extract-online2 --config=exp/nnet2_online/nnet_online/conf/ivector_extractor.conf \\\n"
50  " ark:data/train/spk2utt scp:data/train/feats.scp ark,t:ivectors.1.ark\n";
51 
52  ParseOptions po(usage);
53 
54  OnlineIvectorExtractionConfig ivector_config;
55  ivector_config.Register(&po);
56 
57  g_num_threads = 8;
58  bool repeat = false;
59  int32 length_tolerance = 0;
60  std::string frame_weights_rspecifier;
61 
62  po.Register("num-threads", &g_num_threads,
63  "Number of threads to use for computing derived variables "
64  "of iVector extractor, at process start-up.");
65  po.Register("repeat", &repeat,
66  "If true, output the same number of iVectors as input frames "
67  "(including repeated data).");
68  po.Register("frame-weights-rspecifier", &frame_weights_rspecifier,
69  "Archive of frame weights to scale stats");
70  po.Register("length-tolerance", &length_tolerance,
71  "Tolerance on the difference in number of frames "
72  "for feats and frame weights");
73 
74  po.Read(argc, argv);
75 
76  if (po.NumArgs() != 3) {
77  po.PrintUsage();
78  exit(1);
79  }
80 
81  std::string spk2utt_rspecifier = po.GetArg(1),
82  feature_rspecifier = po.GetArg(2),
83  ivectors_wspecifier = po.GetArg(3);
84 
85  double tot_ubm_loglike = 0.0, tot_objf_impr = 0.0, tot_t = 0.0,
86  tot_length = 0.0, tot_length_utt_end = 0.0;
87  int32 num_done = 0, num_err = 0;
88 
89  ivector_config.use_most_recent_ivector = false;
90  OnlineIvectorExtractionInfo ivector_info(ivector_config);
91 
92  SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
93  RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier);
94  RandomAccessBaseFloatVectorReader frame_weights_reader(frame_weights_rspecifier);
95  BaseFloatMatrixWriter ivector_writer(ivectors_wspecifier);
96 
97  bool warned_dim = false;
98  for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
99  std::string spk = spk2utt_reader.Key();
100  const std::vector<std::string> &uttlist = spk2utt_reader.Value();
101  OnlineIvectorExtractorAdaptationState adaptation_state(
102  ivector_info);
103  for (size_t i = 0; i < uttlist.size(); i++) {
104  std::string utt = uttlist[i];
105  if (!feature_reader.HasKey(utt)) {
106  KALDI_WARN << "Did not find audio for utterance " << utt;
107  num_err++;
108  continue;
109  }
110  const Matrix<BaseFloat> &feats = feature_reader.Value(utt);
111 
112  int32 feat_dim = feats.NumCols();
113  if (feat_dim == ivector_info.ExpectedFeatureDim() + 3) {
114  if (!warned_dim) {
115  KALDI_WARN << "Feature dimension is too large by 3, assuming there are "
116  "pitch features and removing the last 3 dims.";
117  warned_dim = true;
118  }
119  feat_dim -= 3;
120  }
121 
122  SubMatrix<BaseFloat> range = feats.ColRange(0, feat_dim);
123  OnlineMatrixFeature matrix_feature(range);
124 
125  OnlineIvectorFeature ivector_feature(ivector_info,
126  &matrix_feature);
127 
128  ivector_feature.SetAdaptationState(adaptation_state);
129 
130  if (!frame_weights_rspecifier.empty()) {
131  if (!frame_weights_reader.HasKey(utt)) {
132  KALDI_WARN << "Did not find weights for utterance " << utt;
133  num_err++;
134  continue;
135  }
136  const Vector<BaseFloat> &weights = frame_weights_reader.Value(utt);
137 
138  if (std::abs(weights.Dim() - feats.NumRows()) > length_tolerance) {
139  num_err++;
140  continue;
141  }
142 
143  std::vector<std::pair<int32, BaseFloat> > frame_weights;
144  for (int32 i = 0; i < feats.NumRows(); i++) {
145  if (i < weights.Dim())
146  frame_weights.push_back(std::make_pair(i, weights(i)));
147  else
148  frame_weights.push_back(std::make_pair(i, 0.0));
149  }
150 
151 
152  ivector_feature.UpdateFrameWeights(frame_weights);
153  }
154 
155  int32 T = feats.NumRows(),
156  n = (repeat ? 1 : ivector_config.ivector_period),
157  num_ivectors = (T + n - 1) / n;
158 
159  Matrix<BaseFloat> ivectors(num_ivectors,
160  ivector_feature.Dim());
161 
162  for (int32 i = 0; i < num_ivectors; i++) {
163  int32 t = i * n;
164  SubVector<BaseFloat> ivector(ivectors, i);
165  ivector_feature.GetFrame(t, &ivector);
166  }
167  // Update diagnostics.
168 
169  tot_ubm_loglike += T * ivector_feature.UbmLogLikePerFrame();
170  tot_objf_impr += T * ivector_feature.ObjfImprPerFrame();
171  tot_length_utt_end += T * ivectors.Row(num_ivectors - 1).Norm(2.0);
172  for (int32 i = 0; i < num_ivectors; i++)
173  tot_length += T * ivectors.Row(i).Norm(2.0) / num_ivectors;
174  tot_t += T;
175  KALDI_VLOG(2) << "For utterance " << utt << " of speaker " << spk
176  << ", UBM loglike/frame was "
177  << ivector_feature.UbmLogLikePerFrame()
178  << ", iVector length (at utterance end) was "
179  << ivectors.Row(num_ivectors-1).Norm(2.0)
180  << ", objf improvement/frame from iVector estimation was "
181  << ivector_feature.ObjfImprPerFrame();
182 
183  ivector_feature.GetAdaptationState(&adaptation_state);
184  ivector_writer.Write(utt, ivectors);
185  num_done++;
186  }
187  }
188 
189  KALDI_LOG << "Estimated iVectors for " << num_done << " files, " << num_err
190  << " with errors.";
191  KALDI_LOG << "Average objective-function improvement was "
192  << (tot_objf_impr / tot_t) << " per frame, over "
193  << tot_t << " frames (weighted).";
194  KALDI_LOG << "Average iVector length was " << (tot_length / tot_t)
195  << " and at utterance-end was " << (tot_length_utt_end / tot_t)
196  << ", over " << tot_t << " frames (weighted); "
197  << " expected length is "
198  << sqrt(ivector_info.extractor.IvectorDim());
199 
200  return (num_done != 0 ? 0 : 1);
201  } catch(const std::exception &e) {
202  std::cerr << e.what();
203  return -1;
204  }
205 }
This class takes a Matrix<BaseFloat> and wraps it as an OnlineFeatureInterface: this can be useful wh...
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
This struct contains various things that are needed (as const references) by class OnlineIvectorExtra...
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
int32 g_num_threads
Definition: kaldi-thread.cc:25
This class stores the adaptation state from the online iVector extractor, which can help you to initi...
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
SubMatrix< Real > ColRange(const MatrixIndexT col_offset, const MatrixIndexT num_cols) const
Definition: kaldi-matrix.h:213
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
struct rnnlm::@11::@12 n
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
#define KALDI_WARN
Definition: kaldi-error.h:150
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
A class representing a vector.
Definition: kaldi-vector.h:406
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
This class includes configuration variables relating to the online iVector extraction, but not including configuration for the "base feature", i.e.
#define KALDI_LOG
Definition: kaldi-error.h:153
Sub-matrix representation.
Definition: kaldi-matrix.h:988
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:501
OnlineIvectorFeature is an online feature-extraction class that's responsible for extracting iVectors...