ivector-mean.cc
Go to the documentation of this file.
1 // ivectorbin/ivector-mean.cc
2 
3 // Copyright 2013-2014 Daniel Povey
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
21 #include "base/kaldi-common.h"
22 #include "util/common-utils.h"
23 
24 
25 int main(int argc, char *argv[]) {
26  using namespace kaldi;
27  typedef kaldi::int32 int32;
28  try {
29  const char *usage =
30  "With 3 or 4 arguments, averages iVectors over all the\n"
31  "utterances of each speaker using the spk2utt file.\n"
32  "Input the spk2utt file and a set of iVectors indexed by\n"
33  "utterance; output is iVectors indexed by speaker. If 4\n"
34  "arguments are given, extra argument is a table for the number\n"
35  "of utterances per speaker (can be useful for PLDA). If 2\n"
36  "arguments are given, computes the mean of all input files and\n"
37  "writes out the mean vector.\n"
38  "\n"
39  "Usage: ivector-mean <spk2utt-rspecifier> <ivector-rspecifier> "
40  "<ivector-wspecifier> [<num-utt-wspecifier>]\n"
41  "or: ivector-mean <ivector-rspecifier> <mean-wxfilename>\n"
42  "e.g.: ivector-mean data/spk2utt exp/ivectors.ark exp/spk_ivectors.ark exp/spk_num_utts.ark\n"
43  "or: ivector-mean exp/ivectors.ark exp/mean.vec\n"
44  "See also: ivector-subtract-global-mean\n";
45 
46  ParseOptions po(usage);
47  bool binary_write = false;
48  po.Register("binary", &binary_write, "If true, write output in binary "
49  "(only applicable when writing files, not archives/tables.");
50 
51  po.Read(argc, argv);
52 
53  if (po.NumArgs() < 2 || po.NumArgs() > 4) {
54  po.PrintUsage();
55  exit(1);
56  }
57 
58  if (po.NumArgs() == 2) {
59  // Compute the mean of the input vectors and write it out.
60  std::string ivector_rspecifier = po.GetArg(1),
61  mean_wxfilename = po.GetArg(2);
62  int32 num_done = 0;
63  SequentialBaseFloatVectorReader ivector_reader(ivector_rspecifier);
64  Vector<double> sum;
65  for (; !ivector_reader.Done(); ivector_reader.Next()) {
66  if (sum.Dim() == 0) sum.Resize(ivector_reader.Value().Dim());
67  sum.AddVec(1.0, ivector_reader.Value());
68  num_done++;
69  }
70  if (num_done == 0) {
71  KALDI_ERR << "No iVectors read";
72  } else {
73  sum.Scale(1.0 / num_done);
74  WriteKaldiObject(sum, mean_wxfilename, binary_write);
75  return 0;
76  }
77  } else {
78  std::string spk2utt_rspecifier = po.GetArg(1),
79  ivector_rspecifier = po.GetArg(2),
80  ivector_wspecifier = po.GetArg(3),
81  num_utts_wspecifier = po.GetOptArg(4);
82 
83  double spk_sumsq = 0.0;
84  Vector<double> spk_sum;
85 
86  int64 num_spk_done = 0, num_spk_err = 0,
87  num_utt_done = 0, num_utt_err = 0;
88 
89  RandomAccessBaseFloatVectorReader ivector_reader(ivector_rspecifier);
90  SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
91  BaseFloatVectorWriter ivector_writer(ivector_wspecifier);
92  Int32Writer num_utts_writer(num_utts_wspecifier);
93 
94  for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
95  std::string spk = spk2utt_reader.Key();
96  const std::vector<std::string> &uttlist = spk2utt_reader.Value();
97  if (uttlist.empty()) {
98  KALDI_ERR << "Speaker with no utterances.";
99  }
100  Vector<BaseFloat> spk_mean;
101  int32 utt_count = 0;
102  for (size_t i = 0; i < uttlist.size(); i++) {
103  std::string utt = uttlist[i];
104  if (!ivector_reader.HasKey(utt)) {
105  KALDI_WARN << "No iVector present in input for utterance " << utt;
106  num_utt_err++;
107  } else {
108  if (utt_count == 0) {
109  spk_mean = ivector_reader.Value(utt);
110  } else {
111  spk_mean.AddVec(1.0, ivector_reader.Value(utt));
112  }
113  num_utt_done++;
114  utt_count++;
115  }
116  }
117  if (utt_count == 0) {
118  KALDI_WARN << "Not producing output for speaker " << spk
119  << " since no utterances had iVectors";
120  num_spk_err++;
121  } else {
122  spk_mean.Scale(1.0 / utt_count);
123  ivector_writer.Write(spk, spk_mean);
124  if (num_utts_wspecifier != "")
125  num_utts_writer.Write(spk, utt_count);
126  num_spk_done++;
127  spk_sumsq += VecVec(spk_mean, spk_mean);
128  if (spk_sum.Dim() == 0)
129  spk_sum.Resize(spk_mean.Dim());
130  spk_sum.AddVec(1.0, spk_mean);
131  }
132  }
133 
134  KALDI_LOG << "Computed mean of " << num_spk_done << " speakers ("
135  << num_spk_err << " with no utterances), consisting of "
136  << num_utt_done << " utterances (" << num_utt_err
137  << " absent from input).";
138 
139  if (num_spk_done != 0) {
140  spk_sumsq /= num_spk_done;
141  spk_sum.Scale(1.0 / num_spk_done);
142  double mean_length = spk_sum.Norm(2.0),
143  spk_length = sqrt(spk_sumsq),
144  norm_spk_length = spk_length / sqrt(spk_sum.Dim());
145  KALDI_LOG << "Norm of mean of speakers is " << mean_length
146  << ", root-mean-square speaker-iVector length divided by "
147  << "sqrt(dim) is " << norm_spk_length;
148  }
149 
150  return (num_spk_done != 0 ? 0 : 1);
151  }
152  } catch(const std::exception &e) {
153  std::cerr << e.what();
154  return -1;
155  }
156 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
int main(int argc, char *argv[])
Definition: ivector-mean.cc:25
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
Real Norm(Real p) const
Compute the p-th norm of the vector.
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const T & Value(const std::string &key)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
void Scale(Real alpha)
Multiplies all elements by this constant.
bool HasKey(const std::string &key)
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Definition: kaldi-io.h:257
#define KALDI_LOG
Definition: kaldi-error.h:153
Real VecVec(const VectorBase< Real > &a, const VectorBase< Real > &b)
Returns dot product between v1 and v2.
Definition: kaldi-vector.cc:37
void AddVec(const Real alpha, const VectorBase< OtherReal > &v)
Add vector : *this = *this + alpha * rv (with casting between floats and doubles) ...
std::string GetOptArg(int param) const