cmvn-to-nnet.cc File Reference
Include dependency graph for cmvn-to-nnet.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 25 of file cmvn-to-nnet.cc.

References Nnet::AppendComponent(), rnnlm::d, VectorBase< Real >::Dim(), ParseOptions::GetArg(), KALDI_ASSERT, KALDI_LOG, KALDI_WARN, ParseOptions::NumArgs(), MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), ParseOptions::PrintUsage(), ParseOptions::Read(), Matrix< Real >::Read(), ParseOptions::Register(), AddShift::SetParams(), Rescale::SetParams(), Output::Stream(), Input::Stream(), and Nnet::Write().

25  {
26  try {
27  using namespace kaldi;
28  using namespace kaldi::nnet1;
29  typedef kaldi::int32 int32;
30 
31  const char *usage =
32  "Convert cmvn-stats into <AddShift> and <Rescale> components.\n"
33  "Usage: cmvn-to-nnet [options] <transf-in> <nnet-out>\n"
34  "e.g.:\n"
35  " cmvn-to-nnet --binary=false transf.mat nnet.mdl\n";
36 
37 
38  bool binary_write = false;
39  float std_dev = 1.0;
40  float var_floor = 1e-10;
41  float learn_rate_coef = 0.0;
42 
43  ParseOptions po(usage);
44  po.Register("binary", &binary_write, "Write output in binary mode");
45  po.Register("std-dev", &std_dev, "Standard deviation of the output.");
46  po.Register("var-floor", &var_floor,
47  "Floor the variance, so the factors in <Rescale> are bounded.");
48  po.Register("learn-rate-coef", &learn_rate_coef,
49  "Initialize learning-rate coefficient to a value.");
50 
51  po.Read(argc, argv);
52 
53  if (po.NumArgs() != 2) {
54  po.PrintUsage();
55  exit(1);
56  }
57 
58  std::string cmvn_stats_rxfilename = po.GetArg(1),
59  model_out_filename = po.GetArg(2);
60 
61  // read the matrix,
62  Matrix<double> cmvn_stats;
63  {
64  bool binary_read;
65  Input ki(cmvn_stats_rxfilename, &binary_read);
66  cmvn_stats.Read(ki.Stream(), binary_read);
67  }
68  KALDI_ASSERT(cmvn_stats.NumRows() == 2);
69  KALDI_ASSERT(cmvn_stats.NumCols() > 1);
70 
71  int32 num_dims = cmvn_stats.NumCols() - 1;
72  double frame_count = cmvn_stats(0, cmvn_stats.NumCols() - 1);
73 
74  // buffers for shift and scale
75  Vector<BaseFloat> shift(num_dims);
76  Vector<BaseFloat> scale(num_dims);
77 
78  // compute the shift and scale per each dimension
79  for (int32 d = 0; d < num_dims; d++) {
80  BaseFloat mean = cmvn_stats(0, d) / frame_count;
81  BaseFloat var = cmvn_stats(1, d) / frame_count - mean * mean;
82  if (var <= var_floor) {
83  KALDI_WARN << "Very small variance " << var
84  << " flooring to " << var_floor;
85  var = var_floor;
86  }
87  shift(d) = -mean;
88  scale(d) = std_dev / sqrt(var);
89  }
90 
91  // create empty nnet,
92  Nnet nnet;
93 
94  // append shift component to nnet,
95  {
96  AddShift shift_component(shift.Dim(), shift.Dim());
97  shift_component.SetParams(shift);
98  shift_component.SetLearnRateCoef(learn_rate_coef);
99  nnet.AppendComponent(shift_component);
100  }
101 
102  // append scale component to nnet,
103  {
104  Rescale scale_component(scale.Dim(), scale.Dim());
105  scale_component.SetParams(scale);
106  scale_component.SetLearnRateCoef(learn_rate_coef);
107  nnet.AppendComponent(scale_component);
108  }
109 
110  // write the nnet,
111  {
112  Output ko(model_out_filename, binary_write);
113  nnet.Write(ko.Stream(), binary_write);
114  KALDI_LOG << "Written cmvn in 'nnet1' model to: " << model_out_filename;
115  }
116  return 0;
117  } catch(const std::exception &e) {
118  std::cerr << e.what();
119  return -1;
120  }
121 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename'.
Definition: nnet-nnet.cc:367
kaldi::int32 int32
Rescale the data column-wise by a vector (can be used for global variance normalization) ...
Definition: nnet-various.h:404
void SetParams(const VectorBase< BaseFloat > &params)
Set the trainable parameters from 'params', which holds them reshaped as a vector.
Definition: nnet-various.h:462
void Read(std::istream &in, bool binary, bool add=false)
read from stream.
float BaseFloat
Definition: kaldi-types.h:29
Adds shift to all the lines of the matrix (can be used for global mean normalization) ...
Definition: nnet-various.h:291
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
#define KALDI_WARN
Definition: kaldi-error.h:150
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
void SetParams(const VectorBase< BaseFloat > &params)
Set the trainable parameters from 'params', which holds them reshaped as a vector.
Definition: nnet-various.h:349
#define KALDI_LOG
Definition: kaldi-error.h:153
void AppendComponent(const Component &comp)
Append Component to 'this' instance of Nnet (deep copy).
Definition: nnet-nnet.cc:182