ivector-plda-scoring.cc File Reference
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "ivector/plda.h"
Include dependency graph for ivector-plda-scoring.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main(int argc, char *argv[])

Definition at line 26 of file ivector-plda-scoring.cc.

References Plda::Dim(), SequentialTableReader< Holder >::Done(), ParseOptions::GetArg(), RandomAccessTableReader< Holder >::HasKey(), KALDI_ERR, KALDI_LOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), Plda::LogLikelihoodRatio(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), ParseOptions::PrintUsage(), ParseOptions::Read(), kaldi::ReadKaldiObject(), PldaConfig::Register(), ParseOptions::Register(), kaldi::SplitStringToVector(), Output::Stream(), Plda::TransformIvector(), RandomAccessTableReader< Holder >::Value(), and SequentialTableReader< Holder >::Value().

26  {
27  using namespace kaldi;
28  typedef kaldi::int32 int32;
29  typedef std::string string;
30  try {
31  const char *usage =
32  "Computes log-likelihood ratios for trials using PLDA model\n"
33  "Note: the 'trials-file' has lines of the form\n"
34  "<key1> <key2>\n"
35  "and the output will have the form\n"
36  "<key1> <key2> [<dot-product>]\n"
37  "(if either key could not be found, the dot-product field in the output\n"
38  "will be absent, and this program will print a warning)\n"
39  "For training examples, the input is the iVectors averaged over speakers;\n"
40  "a separate archive containing the number of utterances per speaker may be\n"
41  "optionally supplied using the --num-utts option; this affects the PLDA\n"
42  "scoring (if not supplied, it defaults to 1 per speaker).\n"
43  "\n"
44  "Usage: ivector-plda-scoring <plda> <train-ivector-rspecifier> <test-ivector-rspecifier>\n"
45  " <trials-rxfilename> <scores-wxfilename>\n"
46  "\n"
47  "e.g.: ivector-plda-scoring --num-utts=ark:exp/train/num_utts.ark plda "
48  "ark:exp/train/spk_ivectors.ark ark:exp/test/ivectors.ark trials scores\n"
49  "See also: ivector-compute-dot-products, ivector-compute-plda\n";
50 
51  ParseOptions po(usage);
52 
53  std::string num_utts_rspecifier;
54 
55  PldaConfig plda_config;
56  plda_config.Register(&po);
57  po.Register("num-utts", &num_utts_rspecifier, "Table to read the number of "
58  "utterances per speaker, e.g. ark:num_utts.ark\n");
59 
60  po.Read(argc, argv);
61 
62  if (po.NumArgs() != 5) {
63  po.PrintUsage();
64  exit(1);
65  }
66 
67  std::string plda_rxfilename = po.GetArg(1),
68  train_ivector_rspecifier = po.GetArg(2),
69  test_ivector_rspecifier = po.GetArg(3),
70  trials_rxfilename = po.GetArg(4),
71  scores_wxfilename = po.GetArg(5);
72 
73  // diagnostics:
74  double tot_test_renorm_scale = 0.0, tot_train_renorm_scale = 0.0;
75  int64 num_train_ivectors = 0, num_train_errs = 0, num_test_ivectors = 0;
76 
77  int64 num_trials_done = 0, num_trials_err = 0;
78 
79  Plda plda;
80  ReadKaldiObject(plda_rxfilename, &plda);
81 
82  int32 dim = plda.Dim();
83 
84  SequentialBaseFloatVectorReader train_ivector_reader(train_ivector_rspecifier);
85  SequentialBaseFloatVectorReader test_ivector_reader(test_ivector_rspecifier);
86  RandomAccessInt32Reader num_utts_reader(num_utts_rspecifier);
87 
88  typedef unordered_map<string, Vector<BaseFloat>*, StringHasher> HashType;
89 
90  // These hashes will contain the iVectors in the PLDA subspace
91  // (that makes the within-class variance unit and diagonalizes the
92  // between-class covariance). They will also possibly be length-normalized,
93  // depending on the config.
94  HashType train_ivectors, test_ivectors;
95 
96  KALDI_LOG << "Reading train iVectors";
97  for (; !train_ivector_reader.Done(); train_ivector_reader.Next()) {
98  std::string spk = train_ivector_reader.Key();
99  if (train_ivectors.count(spk) != 0) {
100  KALDI_ERR << "Duplicate training iVector found for speaker " << spk;
101  }
102  const Vector<BaseFloat> &ivector = train_ivector_reader.Value();
103  int32 num_examples;
104  if (!num_utts_rspecifier.empty()) {
105  if (!num_utts_reader.HasKey(spk)) {
106  KALDI_WARN << "Number of utterances not given for speaker " << spk;
107  num_train_errs++;
108  continue;
109  }
110  num_examples = num_utts_reader.Value(spk);
111  } else {
112  num_examples = 1;
113  }
114  Vector<BaseFloat> *transformed_ivector = new Vector<BaseFloat>(dim);
115 
116  tot_train_renorm_scale += plda.TransformIvector(plda_config, ivector,
117  num_examples,
118  transformed_ivector);
119  train_ivectors[spk] = transformed_ivector;
120  num_train_ivectors++;
121  }
122  KALDI_LOG << "Read " << num_train_ivectors << " training iVectors, "
123  << "errors on " << num_train_errs;
124  if (num_train_ivectors == 0)
125  KALDI_ERR << "No training iVectors present.";
126  KALDI_LOG << "Average renormalization scale on training iVectors was "
127  << (tot_train_renorm_scale / num_train_ivectors);
128 
129  KALDI_LOG << "Reading test iVectors";
130  for (; !test_ivector_reader.Done(); test_ivector_reader.Next()) {
131  std::string utt = test_ivector_reader.Key();
132  if (test_ivectors.count(utt) != 0) {
133  KALDI_ERR << "Duplicate test iVector found for utterance " << utt;
134  }
135  const Vector<BaseFloat> &ivector = test_ivector_reader.Value();
136  int32 num_examples = 1; // this value is always used for test (affects the
137  // length normalization in the TransformIvector
138  // function).
139  Vector<BaseFloat> *transformed_ivector = new Vector<BaseFloat>(dim);
140 
141  tot_test_renorm_scale += plda.TransformIvector(plda_config, ivector,
142  num_examples,
143  transformed_ivector);
144  test_ivectors[utt] = transformed_ivector;
145  num_test_ivectors++;
146  }
147  KALDI_LOG << "Read " << num_test_ivectors << " test iVectors.";
148  if (num_test_ivectors == 0)
149  KALDI_ERR << "No test iVectors present.";
150  KALDI_LOG << "Average renormalization scale on test iVectors was "
151  << (tot_test_renorm_scale / num_test_ivectors);
152 
153 
154  Input ki(trials_rxfilename);
155  bool binary = false;
156  Output ko(scores_wxfilename, binary);
157 
158  double sum = 0.0, sumsq = 0.0;
159  std::string line;
160 
161  while (std::getline(ki.Stream(), line)) {
162  std::vector<std::string> fields;
163  SplitStringToVector(line, " \t\n\r", true, &fields);
164  if (fields.size() != 2) {
165  KALDI_ERR << "Bad line " << (num_trials_done + num_trials_err)
166  << "in input (expected two fields: key1 key2): " << line;
167  }
168  std::string key1 = fields[0], key2 = fields[1];
169  if (train_ivectors.count(key1) == 0) {
170  KALDI_WARN << "Key " << key1 << " not present in training iVectors.";
171  num_trials_err++;
172  continue;
173  }
174  if (test_ivectors.count(key2) == 0) {
175  KALDI_WARN << "Key " << key2 << " not present in test iVectors.";
176  num_trials_err++;
177  continue;
178  }
179  const Vector<BaseFloat> *train_ivector = train_ivectors[key1],
180  *test_ivector = test_ivectors[key2];
181 
182  Vector<double> train_ivector_dbl(*train_ivector),
183  test_ivector_dbl(*test_ivector);
184 
185  int32 num_train_examples;
186  if (!num_utts_rspecifier.empty()) {
187  // we already checked that it has this key.
188  num_train_examples = num_utts_reader.Value(key1);
189  } else {
190  num_train_examples = 1;
191  }
192 
193 
194  BaseFloat score = plda.LogLikelihoodRatio(train_ivector_dbl,
195  num_train_examples,
196  test_ivector_dbl);
197  sum += score;
198  sumsq += score * score;
199  num_trials_done++;
200  ko.Stream() << key1 << ' ' << key2 << ' ' << score << std::endl;
201  }
202 
203  for (HashType::iterator iter = train_ivectors.begin();
204  iter != train_ivectors.end(); ++iter)
205  delete iter->second;
206  for (HashType::iterator iter = test_ivectors.begin();
207  iter != test_ivectors.end(); ++iter)
208  delete iter->second;
209 
210 
211  if (num_trials_done != 0) {
212  BaseFloat mean = sum / num_trials_done, scatter = sumsq / num_trials_done,
213  variance = scatter - mean * mean, stddev = sqrt(variance);
214  KALDI_LOG << "Mean score was " << mean << ", standard deviation was "
215  << stddev;
216  }
217  KALDI_LOG << "Processed " << num_trials_done << " trials, " << num_trials_err
218  << " had errors.";
219  return (num_trials_done != 0 ? 0 : 1);
220  } catch(const std::exception &e) {
221  std::cerr << e.what();
222  return -1;
223  }
224 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
double TransformIvector(const PldaConfig &config, const VectorBase< double > &ivector, int32 num_enroll_examples, VectorBase< double > *transformed_ivector) const
Transforms an iVector into a space where the within-class variance is unit and between-class variance...
Definition: plda.cc:120
kaldi::int32 int32
A hashing function object for strings.
Definition: stl-utils.h:248
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void SplitStringToVector(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< std::string > *out)
Split a string using any of the single character delimiters.
Definition: text-utils.cc:63
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
double LogLikelihoodRatio(const VectorBase< double > &transformed_enroll_ivector, int32 num_enroll_utts, const VectorBase< double > &transformed_test_ivector) const
Returns the log-likelihood ratio log (p(test_ivector | same) / p(test_ivector | different)).
Definition: plda.cc:153
void Register(OptionsItf *opts)
Definition: plda.h:56
A class representing a vector.
Definition: kaldi-vector.h:406
int32 Dim() const
Definition: plda.h:140
#define KALDI_LOG
Definition: kaldi-error.h:153