gmm-ismooth-stats.cc
Go to the documentation of this file.
1 // gmmbin/gmm-ismooth-stats.cc
2 
3 // Copyright 2009-2011 Petr Motlicek Chao Weng
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include "base/kaldi-common.h"
21 #include "util/common-utils.h"
22 #include "gmm/am-diag-gmm.h"
23 #include "tree/context-dep.h"
24 #include "hmm/transition-model.h"
25 #include "gmm/ebw-diag-gmm.h"
26 
27 int main(int argc, char *argv[]) {
28  try {
29  using namespace kaldi;
30  typedef kaldi::int32 int32;
31 
32  const char *usage =
33  "Apply I-smoothing to statistics, e.g. for discriminative training\n"
34  "Usage: gmm-ismooth-stats [options] [--smooth-from-model] [<src-stats-in>|<src-model-in>] <dst-stats-in> <stats-out>\n"
35  "e.g.: gmm-ismooth-stats --tau=100 ml.acc num.acc smoothed.acc\n"
36  "or: gmm-ismooth-stats --tau=50 --smooth-from-model 1.mdl num.acc smoothed.acc\n"
37  "or: gmm-ismooth-stats --tau=100 num.acc num.acc smoothed.acc\n";
38 
39  bool binary_write = false;
40  bool smooth_from_model = false;
41  BaseFloat tau = 100;
42 
43  ParseOptions po(usage);
44  po.Register("binary", &binary_write, "Write output in binary mode");
45  po.Register("smooth-from-model", &smooth_from_model, "If true, "
46  "expect first argument to be a model file");
47  po.Register("tau", &tau, "Tau value for I-smoothing");
48 
49  po.Read(argc, argv);
50 
51  if (po.NumArgs() != 3) {
52  po.PrintUsage();
53  exit(1);
54  }
55 
56  std::string src_stats_or_model_filename = po.GetArg(1),
57  dst_stats_filename = po.GetArg(2),
58  stats_out_filename = po.GetArg(3);
59 
60  double tot_count_before, tot_count_after;
61 
62  if (src_stats_or_model_filename == dst_stats_filename) { // as an optimization, just read once.
63  KALDI_ASSERT(!smooth_from_model);
64  Vector<double> transition_accs;
65  AccumAmDiagGmm stats;
66  {
67  bool binary;
68  Input ki(dst_stats_filename, &binary);
69  transition_accs.Read(ki.Stream(), binary);
70  stats.Read(ki.Stream(), binary, true); // true == add; doesn't matter here.
71  }
72  tot_count_before = stats.TotStatsCount();
73  IsmoothStatsAmDiagGmm(stats, tau, &stats);
74  tot_count_after = stats.TotStatsCount();
75  Output ko(stats_out_filename, binary_write);
76  transition_accs.Write(ko.Stream(), binary_write);
77  stats.Write(ko.Stream(), binary_write);
78  } else if (smooth_from_model) { // Smoothing from model...
79  AmDiagGmm am_gmm;
80  TransitionModel trans_model;
81  Vector<double> dst_transition_accs;
82  AccumAmDiagGmm dst_stats;
83  { // read src model
84  bool binary;
85  Input ki(src_stats_or_model_filename, &binary);
86  trans_model.Read(ki.Stream(), binary);
87  am_gmm.Read(ki.Stream(), binary);
88  }
89  { // read dst stats.
90  bool binary;
91  Input ki(dst_stats_filename, &binary);
92  dst_transition_accs.Read(ki.Stream(), binary);
93  dst_stats.Read(ki.Stream(), binary, true); // true == add; doesn't matter here.
94  }
95  tot_count_before = dst_stats.TotStatsCount();
96  IsmoothStatsAmDiagGmmFromModel(am_gmm, tau, &dst_stats);
97  tot_count_after = dst_stats.TotStatsCount();
98  Output ko(stats_out_filename, binary_write);
99  dst_transition_accs.Write(ko.Stream(), binary_write);
100  dst_stats.Write(ko.Stream(), binary_write);
101  } else { // Smooth from stats.
102  Vector<double> src_transition_accs;
103  Vector<double> dst_transition_accs;
104  AccumAmDiagGmm src_stats;
105  AccumAmDiagGmm dst_stats;
106  { // read src stats.
107  bool binary;
108  Input ki(src_stats_or_model_filename, &binary);
109  src_transition_accs.Read(ki.Stream(), binary);
110  src_stats.Read(ki.Stream(), binary, true); // true == add; doesn't matter here.
111  }
112  { // read dst stats.
113  bool binary;
114  Input ki(dst_stats_filename, &binary);
115  dst_transition_accs.Read(ki.Stream(), binary);
116  dst_stats.Read(ki.Stream(), binary, true); // true == add; doesn't matter here.
117  }
118  tot_count_before = dst_stats.TotStatsCount();
119  IsmoothStatsAmDiagGmm(src_stats, tau, &dst_stats);
120  tot_count_after = dst_stats.TotStatsCount();
121 
122  Output ko(stats_out_filename, binary_write);
123  dst_transition_accs.Write(ko.Stream(), binary_write);
124  dst_stats.Write(ko.Stream(), binary_write);
125  }
126  KALDI_LOG << "Smoothed stats with tau = " << tau << ", count changed from "
127  << tot_count_before << " to " << tot_count_after;
128  } catch(const std::exception &e) {
129  std::cerr << e.what() << '\n';
130  return -1;
131  }
132 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
Definition: chain.dox:20
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void Write(std::ostream &Out, bool binary) const
Writes to C++ stream (option to write in binary).
kaldi::int32 int32
void Register(const std::string &name, bool *ptr, const std::string &doc)
std::istream & Stream()
Definition: kaldi-io.cc:826
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
std::ostream & Stream()
Definition: kaldi-io.cc:701
void IsmoothStatsAmDiagGmmFromModel(const AmDiagGmm &src_model, double tau, AccumAmDiagGmm *dst_stats)
This version of the I-smoothing function takes a model as input.
void Read(std::istream &is, bool binary)
void IsmoothStatsAmDiagGmm(const AccumAmDiagGmm &src_stats, double tau, AccumAmDiagGmm *dst_stats)
Smooth "dst_stats" with "src_stats".
BaseFloat TotStatsCount() const
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
void Read(std::istream &in_stream, bool binary, bool add=false)
int NumArgs() const
Number of positional parameters (c.f. argc-1).
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void Write(std::ostream &out_stream, bool binary) const
int main(int argc, char *argv[])
#define KALDI_LOG
Definition: kaldi-error.h:153
void Read(std::istream &in_stream, bool binary)
Definition: am-diag-gmm.cc:147
void Read(std::istream &in, bool binary, bool add=false)
Read function using C++ streams.