sgmm2-sum-accs.cc File Reference
Include dependency graph for sgmm2-sum-accs.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main (int argc, char *argv[])

Definition at line 26 of file sgmm2-sum-accs.cc.

References ParseOptions::GetArg(), rnnlm::i, kaldi::InitKaldiInputStream(), KALDI_ERR, KALDI_LOG, ParseOptions::NumArgs(), ParseOptions::PrintUsage(), ParseOptions::Read(), MleAmSgmm2Accs::Read(), Vector< Real >::Read(), ParseOptions::Register(), Output::Stream(), Input::Stream(), MleAmSgmm2Accs::Write(), and VectorBase< Real >::Write().

26  {
27  try {
28  typedef kaldi::int32 int32;
29 
30  const char *usage =
31  "Sum multiple accumulated stats files for SGMM training.\n"
32  "Usage: sgmm2-sum-accs [options] stats-out stats-in1 stats-in2 ...\n";
33 
34  bool binary = true;
35  bool parallel = false;
36  kaldi::ParseOptions po(usage);
37  po.Register("binary", &binary, "Write output in binary mode");
38  po.Register("parallel", &parallel, "If true, the program makes sure to open all "
39  "filehandles before reading for any (useful when summing accs from "
40  "long processes)");
41  po.Read(argc, argv);
42 
43  if (po.NumArgs() < 2) {
44  po.PrintUsage();
45  exit(1);
46  }
47 
48  std::string stats_out_filename = po.GetArg(1);
49  kaldi::Vector<double> transition_accs;
50  kaldi::MleAmSgmm2Accs sgmm_accs;
51 
52  if (parallel) {
53  std::vector<kaldi::Input*> inputs(po.NumArgs() - 1);
54  for (int i = 0; i < po.NumArgs() - 1; i++) {
55  std::string stats_in_filename = po.GetArg(i + 2);
56  inputs[i] = new kaldi::Input(stats_in_filename); // Don't try
57  // to work out binary status yet; this would cause us to wait
58  // for the output of that process. We delay it till later.
59  }
60  for (size_t i = 0; i < po.NumArgs() - 1; i++) {
61  bool b;
62  if (kaldi::InitKaldiInputStream(inputs[i]->Stream(), &b)) {
63  transition_accs.Read(inputs[i]->Stream(), b, true /* add values */);
64  sgmm_accs.Read(inputs[i]->Stream(), b, true /* add values */);
65  delete inputs[i];
66  } else {
67  KALDI_ERR << "Failed to read input stats file " << po.GetArg(i + 2);
68  }
69  }
70  } else {
71  for (int i = 2, max = po.NumArgs(); i <= max; i++) {
72  std::string stats_in_filename = po.GetArg(i);
73  bool binary_read;
74  kaldi::Input ki(stats_in_filename, &binary_read);
75  transition_accs.Read(ki.Stream(), binary_read, true /* add values */);
76  sgmm_accs.Read(ki.Stream(), binary_read, true /* add values */);
77  }
78  }
79 
80  // Write out the accs
81  {
82  kaldi::Output ko(stats_out_filename, binary);
83  transition_accs.Write(ko.Stream(), binary);
84  sgmm_accs.Write(ko.Stream(), binary);
85  }
86 
87  KALDI_LOG << "Written stats to " << stats_out_filename;
88  } catch(const std::exception &e) {
89  std::cerr << e.what() << '\n';
90  return -1;
91  }
92 }
bool InitKaldiInputStream(std::istream &is, bool *binary)
Initialize an opened stream for reading by detecting the binary header and setting *binary accordingly.
Definition: io-funcs-inl.h:306
void Read(std::istream &in_stream, bool binary, bool add)
void Write(std::ostream &Out, bool binary) const
Writes to C++ stream (option to write in binary).
kaldi::int32 int32
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
#define KALDI_ERR
Definition: kaldi-error.h:147
void Write(std::ostream &out_stream, bool binary) const
Class for the accumulators associated with the phonetic-subspace model parameters.
#define KALDI_LOG
Definition: kaldi-error.h:153
void Read(std::istream &in, bool binary, bool add=false)
Read function using C++ streams.