sgmm2-sum-accs.cc
Go to the documentation of this file.
1 // sgmm2bin/sgmm2-sum-accs.cc
2 
3 // Copyright 2009-2012 Saarland University; Microsoft Corporation
4 // Johns Hopkins University (author: Daniel Povey)
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
#include <memory>
#include <string>
#include <vector>

#include "util/common-utils.h"
// NOTE(review): the original listing skips line 22 here; kaldi::MleAmSgmm2Accs
// needs its declaring header (presumably "sgmm2/estimate-am-sgmm2.h") — confirm.
#include "hmm/transition-model.h"
24 
25 
26 int main(int argc, char *argv[]) {
27  try {
28  typedef kaldi::int32 int32;
29 
30  const char *usage =
31  "Sum multiple accumulated stats files for SGMM training.\n"
32  "Usage: sgmm2-sum-accs [options] stats-out stats-in1 stats-in2 ...\n";
33 
34  bool binary = true;
35  bool parallel = false;
36  kaldi::ParseOptions po(usage);
37  po.Register("binary", &binary, "Write output in binary mode");
38  po.Register("parallel", &parallel, "If true, the program makes sure to open all "
39  "filehandles before reading for any (useful when summing accs from "
40  "long processes)");
41  po.Read(argc, argv);
42 
43  if (po.NumArgs() < 2) {
44  po.PrintUsage();
45  exit(1);
46  }
47 
48  std::string stats_out_filename = po.GetArg(1);
49  kaldi::Vector<double> transition_accs;
50  kaldi::MleAmSgmm2Accs sgmm_accs;
51 
52  if (parallel) {
53  std::vector<kaldi::Input*> inputs(po.NumArgs() - 1);
54  for (int i = 0; i < po.NumArgs() - 1; i++) {
55  std::string stats_in_filename = po.GetArg(i + 2);
56  inputs[i] = new kaldi::Input(stats_in_filename); // Don't try
57  // to work out binary status yet; this would cause us to wait
58  // for the output of that process. We delay it till later.
59  }
60  for (size_t i = 0; i < po.NumArgs() - 1; i++) {
61  bool b;
62  if (kaldi::InitKaldiInputStream(inputs[i]->Stream(), &b)) {
63  transition_accs.Read(inputs[i]->Stream(), b, true /* add values */);
64  sgmm_accs.Read(inputs[i]->Stream(), b, true /* add values */);
65  delete inputs[i];
66  } else {
67  KALDI_ERR << "Failed to read input stats file " << po.GetArg(i + 2);
68  }
69  }
70  } else {
71  for (int i = 2, max = po.NumArgs(); i <= max; i++) {
72  std::string stats_in_filename = po.GetArg(i);
73  bool binary_read;
74  kaldi::Input ki(stats_in_filename, &binary_read);
75  transition_accs.Read(ki.Stream(), binary_read, true /* add values */);
76  sgmm_accs.Read(ki.Stream(), binary_read, true /* add values */);
77  }
78  }
79 
80  // Write out the accs
81  {
82  kaldi::Output ko(stats_out_filename, binary);
83  transition_accs.Write(ko.Stream(), binary);
84  sgmm_accs.Write(ko.Stream(), binary);
85  }
86 
87  KALDI_LOG << "Written stats to " << stats_out_filename;
88  } catch(const std::exception &e) {
89  std::cerr << e.what() << '\n';
90  return -1;
91  }
92 }
93 
94 
bool InitKaldiInputStream(std::istream &is, bool *binary)
Initialize an opened stream for reading by detecting the binary header and setting the "binary" flag accordingly; returns true on success.
Definition: io-funcs-inl.h:306
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void Read(std::istream &in_stream, bool binary, bool add)
void Write(std::ostream &Out, bool binary) const
Writes to C++ stream (option to write in binary).
kaldi::int32 int32
void Register(const std::string &name, bool *ptr, const std::string &doc)
std::istream & Stream()
Definition: kaldi-io.cc:826
The class ParseOptions is for parsing command-line options; see "Parsing command-line options" for more details.
Definition: parse-options.h:36
std::ostream & Stream()
Definition: kaldi-io.cc:701
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
void Write(std::ostream &out_stream, bool binary) const
int NumArgs() const
Number of positional parameters (c.f. argc-1).
int main(int argc, char *argv[])
Class for the accumulators associated with the phonetic-subspace model parameters.
#define KALDI_LOG
Definition: kaldi-error.h:153
void Read(std::istream &in, bool binary, bool add=false)
Read function using C++ streams.