vector-sum.cc
Go to the documentation of this file.
1 // bin/vector-sum.cc
2 
3 // Copyright 2014 Vimal Manohar
4 // 2014-2018 Johns Hopkins University (author: Daniel Povey)
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABILITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include <vector>
22 #include <string>
23 
24 using std::vector;
25 using std::string;
26 
27 #include "base/kaldi-common.h"
28 #include "util/common-utils.h"
29 #include "matrix/kaldi-vector.h"
31 
32 
33 namespace kaldi {
34 
35 // sums a bunch of archives to produce one archive
37  int32 num_args = po.NumArgs();
38  std::string vector_in_fn1 = po.GetArg(1),
39  vector_out_fn = po.GetArg(num_args);
40 
41  // Output vector
42  BaseFloatVectorWriter vector_writer(vector_out_fn);
43 
44  // Input vectors
45  SequentialBaseFloatVectorReader vector_reader1(vector_in_fn1);
46  std::vector<RandomAccessBaseFloatVectorReader*> vector_readers(num_args-2,
47  static_cast<RandomAccessBaseFloatVectorReader*>(NULL));
48  std::vector<std::string> vector_in_fns(num_args-2);
49  for (int32 i = 2; i < num_args; ++i) {
50  vector_readers[i-2] = new RandomAccessBaseFloatVectorReader(po.GetArg(i));
51  vector_in_fns[i-2] = po.GetArg(i);
52  }
53 
54  int32 n_utts = 0, n_total_vectors = 0,
55  n_success = 0, n_missing = 0, n_other_errors = 0;
56 
57  for (; !vector_reader1.Done(); vector_reader1.Next()) {
58  std::string key = vector_reader1.Key();
59  Vector<BaseFloat> vector1 = vector_reader1.Value();
60  vector_reader1.FreeCurrent();
61  n_utts++;
62  n_total_vectors++;
63 
64  Vector<BaseFloat> vector_out(vector1);
65 
66  for (int32 i = 0; i < num_args-2; ++i) {
67  if (vector_readers[i]->HasKey(key)) {
68  Vector<BaseFloat> vector2 = vector_readers[i]->Value(key);
69  n_total_vectors++;
70  if (vector2.Dim() == vector_out.Dim()) {
71  vector_out.AddVec(1.0, vector2);
72  } else {
73  KALDI_WARN << "Dimension mismatch for utterance " << key
74  << " : " << vector2.Dim() << " for "
75  << "system " << (i + 2) << ", rspecifier: "
76  << vector_in_fns[i] << " vs " << vector_out.Dim()
77  << " primary vector, rspecifier:" << vector_in_fn1;
78  n_other_errors++;
79  }
80  } else {
81  KALDI_WARN << "No vector found for utterance " << key << " for "
82  << "system " << (i + 2) << ", rspecifier: "
83  << vector_in_fns[i];
84  n_missing++;
85  }
86  }
87 
88  vector_writer.Write(key, vector_out);
89  n_success++;
90  }
91 
92  KALDI_LOG << "Processed " << n_utts << " utterances: with a total of "
93  << n_total_vectors << " vectors across " << (num_args-1)
94  << " different systems";
95  KALDI_LOG << "Produced output for " << n_success << " utterances; "
96  << n_missing << " total missing vectors";
97 
98  DeletePointers(&vector_readers);
99 
100  return (n_success != 0 && n_missing < (n_success - n_missing)) ? 0 : 1;
101 }
102 
104  bool binary,
105  bool average = false) {
106  KALDI_ASSERT(po.NumArgs() == 2);
107  KALDI_ASSERT(ClassifyRspecifier(po.GetArg(1), NULL, NULL) != kNoRspecifier &&
108  "vector-sum: first argument must be an rspecifier");
109  // if next assert fails it would be bug in the code as otherwise we shouldn't
110  // be called.
111  KALDI_ASSERT(ClassifyWspecifier(po.GetArg(2), NULL, NULL, NULL) ==
112  kNoWspecifier);
113 
114  SequentialBaseFloatVectorReader vec_reader(po.GetArg(1));
115 
116  Vector<double> sum;
117 
118  int32 num_done = 0, num_err = 0;
119 
120  for (; !vec_reader.Done(); vec_reader.Next()) {
121  const Vector<BaseFloat> &vec = vec_reader.Value();
122  if (vec.Dim() == 0) {
123  KALDI_WARN << "Zero vector input for key " << vec_reader.Key();
124  num_err++;
125  } else {
126  if (sum.Dim() == 0) sum.Resize(vec.Dim());
127  if (sum.Dim() != vec.Dim()) {
128  KALDI_WARN << "Dimension mismatch for key " << vec_reader.Key()
129  << ": " << vec.Dim() << " vs. " << sum.Dim();
130  num_err++;
131  } else {
132  sum.AddVec(1.0, vec);
133  num_done++;
134  }
135  }
136  }
137 
138  if (num_done > 0 && average) sum.Scale(1.0 / num_done);
139 
140  Vector<BaseFloat> sum_float(sum);
141  WriteKaldiObject(sum_float, po.GetArg(2), binary);
142 
143  KALDI_LOG << "Summed " << num_done << " vectors, "
144  << num_err << " with errors; wrote sum to "
145  << PrintableWxfilename(po.GetArg(2));
146  return (num_done > 0 && num_err < num_done) ? 0 : 1;
147 }
148 
149 // sum a bunch of single files to produce a single file [including
150 // extended filenames, of course]
152  bool binary) {
153  KALDI_ASSERT(po.NumArgs() >= 2);
154  for (int32 i = 1; i < po.NumArgs(); i++) {
155  if (ClassifyRspecifier(po.GetArg(i), NULL, NULL) != kNoRspecifier) {
156  KALDI_ERR << "Wrong usage (type 3): if first and last arguments are not "
157  << "tables, the intermediate arguments must not be tables.";
158  }
159  }
160  if (ClassifyWspecifier(po.GetArg(po.NumArgs()), NULL, NULL, NULL) !=
161  kNoWspecifier) {
162  KALDI_ERR << "Wrong usage (type 3): if first and last arguments are not "
163  << "tables, the intermediate arguments must not be tables.";
164  }
165 
166  Vector<BaseFloat> sum;
167  for (int32 i = 1; i < po.NumArgs(); i++) {
168  Vector<BaseFloat> this_vec;
169  ReadKaldiObject(po.GetArg(i), &this_vec);
170  if (sum.Dim() < this_vec.Dim())
171  sum.Resize(this_vec.Dim(), kCopyData);;
172  sum.AddVec(1.0, this_vec);
173  }
174  WriteKaldiObject(sum, po.GetArg(po.NumArgs()), binary);
175  KALDI_LOG << "Summed " << (po.NumArgs() - 1) << " vectors; "
176  << "wrote sum to " << PrintableWxfilename(po.GetArg(po.NumArgs()));
177  return 0;
178 }
179 
180 
181 } // namespace kaldi
182 
183 
184 int main(int argc, char *argv[]) {
185  try {
186  using namespace kaldi;
187 
188  const char *usage =
189  "Add vectors (e.g. weights, transition-accs; speaker vectors)\n"
190  "If you need to scale the inputs, use vector-scale on the inputs\n"
191  "\n"
192  "Type one usage:\n"
193  " vector-sum [options] <vector-in-rspecifier1> [<vector-in-rspecifier2>"
194  " <vector-in-rspecifier3> ...] <vector-out-wspecifier>\n"
195  " e.g.: vector-sum ark:1.weights ark:2.weights ark:combine.weights\n"
196  "Type two usage (sums a single table input to produce a single output):\n"
197  " vector-sum [options] <vector-in-rspecifier> <vector-out-wxfilename>\n"
198  " e.g.: vector-sum --binary=false vecs.ark sum.vec\n"
199  "Type three usage (sums single-file inputs to produce a single output):\n"
200  " vector-sum [options] <vector-in-rxfilename1> <vector-in-rxfilename2> ..."
201  " <vector-out-wxfilename>\n"
202  " e.g.: vector-sum --binary=false 1.vec 2.vec 3.vec sum.vec\n"
203  "See also: copy-vector, dot-weights\n";
204 
205  bool binary, average = false;
206 
207  ParseOptions po(usage);
208 
209  po.Register("binary", &binary, "If true, write output as binary (only "
210  "relevant for usage types two or three");
211  po.Register("average", &average, "Do average instead of sum");
212 
213  po.Read(argc, argv);
214 
215  int32 N = po.NumArgs(), exit_status;
216 
217  if (po.NumArgs() >= 2 &&
218  ClassifyWspecifier(po.GetArg(N), NULL, NULL, NULL) != kNoWspecifier) {
219  // output to table.
220  exit_status = TypeOneUsage(po);
221  } else if (po.NumArgs() == 2 &&
222  ClassifyRspecifier(po.GetArg(1), NULL, NULL) != kNoRspecifier &&
223  ClassifyWspecifier(po.GetArg(N), NULL, NULL, NULL) ==
224  kNoWspecifier) {
225  // input from a single table, output not to table.
226  exit_status = TypeTwoUsage(po, binary, average);
227  } else if (po.NumArgs() >= 2 &&
228  ClassifyRspecifier(po.GetArg(1), NULL, NULL) == kNoRspecifier &&
229  ClassifyWspecifier(po.GetArg(N), NULL, NULL, NULL) ==
230  kNoWspecifier) {
231  // summing flat files.
232  exit_status = TypeThreeUsage(po, binary);
233  } else {
234  po.PrintUsage();
235  exit(1);
236  }
237  return exit_status;
238  } catch(const std::exception &e) {
239  std::cerr << e.what();
240  return -1;
241  }
242 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void DeletePointers(std::vector< A *> *v)
Deletes any non-NULL pointers in the vector v, and sets the corresponding entries of v to NULL...
Definition: stl-utils.h:184
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
RspecifierType ClassifyRspecifier(const std::string &rspecifier, std::string *rxfilename, RspecifierOptions *opts)
Definition: kaldi-table.cc:225
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
int32 TypeTwoUsage(const ParseOptions &po, bool binary)
Definition: matrix-sum.cc:179
int32 TypeOneUsage(const ParseOptions &po, BaseFloat scale1, BaseFloat scale2)
Definition: matrix-sum.cc:30
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int main(int argc, char *argv[])
Definition: vector-sum.cc:184
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
WspecifierType ClassifyWspecifier(const std::string &wspecifier, std::string *archive_wxfilename, std::string *script_wxfilename, WspecifierOptions *opts)
Definition: kaldi-table.cc:135
int NumArgs() const
Number of positional parameters (c.f. argc-1).
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Definition: kaldi-io.h:257
std::string PrintableWxfilename(const std::string &wxfilename)
PrintableWxfilename turns the wxfilename into a more human-readable form for error reporting...
Definition: kaldi-io.cc:73
#define KALDI_LOG
Definition: kaldi-error.h:153
void AddVec(const Real alpha, const VectorBase< OtherReal > &v)
Add vector : *this = *this + alpha * rv (with casting between floats and doubles) ...
RandomAccessTableReader< KaldiObjectHolder< Vector< BaseFloat > > > RandomAccessBaseFloatVectorReader
Definition: table-types.h:62
int32 TypeThreeUsage(const ParseOptions &po, bool binary, bool average)
Definition: matrix-sum.cc:226