cmvn-to-nnet.cc
Go to the documentation of this file.
1 // nnetbin/cmvn-to-nnet.cc
2 
3 // Copyright 2012-2016 Brno University of Technology
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include "base/kaldi-common.h"
21 #include "util/common-utils.h"
22 #include "nnet/nnet-nnet.h"
23 #include "nnet/nnet-various.h"
24 
25 int main(int argc, char *argv[]) {
26  try {
27  using namespace kaldi;
28  using namespace kaldi::nnet1;
29  typedef kaldi::int32 int32;
30 
31  const char *usage =
32  "Convert cmvn-stats into <AddShift> and <Rescale> components.\n"
33  "Usage: cmvn-to-nnet [options] <transf-in> <nnet-out>\n"
34  "e.g.:\n"
35  " cmvn-to-nnet --binary=false transf.mat nnet.mdl\n";
36 
37 
38  bool binary_write = false;
39  float std_dev = 1.0;
40  float var_floor = 1e-10;
41  float learn_rate_coef = 0.0;
42 
43  ParseOptions po(usage);
44  po.Register("binary", &binary_write, "Write output in binary mode");
45  po.Register("std-dev", &std_dev, "Standard deviation of the output.");
46  po.Register("var-floor", &var_floor,
47  "Floor the variance, so the factors in <Rescale> are bounded.");
48  po.Register("learn-rate-coef", &learn_rate_coef,
49  "Initialize learning-rate coefficient to a value.");
50 
51  po.Read(argc, argv);
52 
53  if (po.NumArgs() != 2) {
54  po.PrintUsage();
55  exit(1);
56  }
57 
58  std::string cmvn_stats_rxfilename = po.GetArg(1),
59  model_out_filename = po.GetArg(2);
60 
61  // read the matrix,
62  Matrix<double> cmvn_stats;
63  {
64  bool binary_read;
65  Input ki(cmvn_stats_rxfilename, &binary_read);
66  cmvn_stats.Read(ki.Stream(), binary_read);
67  }
68  KALDI_ASSERT(cmvn_stats.NumRows() == 2);
69  KALDI_ASSERT(cmvn_stats.NumCols() > 1);
70 
71  int32 num_dims = cmvn_stats.NumCols() - 1;
72  double frame_count = cmvn_stats(0, cmvn_stats.NumCols() - 1);
73 
74  // buffers for shift and scale
75  Vector<BaseFloat> shift(num_dims);
76  Vector<BaseFloat> scale(num_dims);
77 
78  // compute the shift and scale per each dimension
79  for (int32 d = 0; d < num_dims; d++) {
80  BaseFloat mean = cmvn_stats(0, d) / frame_count;
81  BaseFloat var = cmvn_stats(1, d) / frame_count - mean * mean;
82  if (var <= var_floor) {
83  KALDI_WARN << "Very small variance " << var
84  << " flooring to " << var_floor;
85  var = var_floor;
86  }
87  shift(d) = -mean;
88  scale(d) = std_dev / sqrt(var);
89  }
90 
91  // create empty nnet,
92  Nnet nnet;
93 
94  // append shift component to nnet,
95  {
96  AddShift shift_component(shift.Dim(), shift.Dim());
97  shift_component.SetParams(shift);
98  shift_component.SetLearnRateCoef(learn_rate_coef);
99  nnet.AppendComponent(shift_component);
100  }
101 
102  // append scale component to nnet,
103  {
104  Rescale scale_component(scale.Dim(), scale.Dim());
105  scale_component.SetParams(scale);
106  scale_component.SetLearnRateCoef(learn_rate_coef);
107  nnet.AppendComponent(scale_component);
108  }
109 
110  // write the nnet,
111  {
112  Output ko(model_out_filename, binary_write);
113  nnet.Write(ko.Stream(), binary_write);
114  KALDI_LOG << "Written cmvn in 'nnet1' model to: " << model_out_filename;
115  }
116  return 0;
117  } catch(const std::exception &e) {
118  std::cerr << e.what();
119  return -1;
120  }
121 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
int main(int argc, char *argv[])
Definition: cmvn-to-nnet.cc:25
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename'.
Definition: nnet-nnet.cc:367
kaldi::int32 int32
Rescale the data column-wise by a vector (can be used for global variance normalization) ...
Definition: nnet-various.h:404
void Register(const std::string &name, bool *ptr, const std::string &doc)
void SetParams(const VectorBase< BaseFloat > &params)
Set the trainable parameters from 'params', given as a vector.
Definition: nnet-various.h:462
std::istream & Stream()
Definition: kaldi-io.cc:826
void Read(std::istream &in, bool binary, bool add=false)
read from stream.
float BaseFloat
Definition: kaldi-types.h:29
Adds shift to all the lines of the matrix (can be used for global mean normalization) ...
Definition: nnet-various.h:291
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
std::ostream & Stream()
Definition: kaldi-io.cc:701
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
int NumArgs() const
Number of positional parameters (c.f. argc-1).
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
void SetParams(const VectorBase< BaseFloat > &params)
Set the trainable parameters from 'params', given as a vector.
Definition: nnet-various.h:349
#define KALDI_LOG
Definition: kaldi-error.h:153
void AppendComponent(const Component &comp)
Append Component to 'this' instance of Nnet (deep copy).
Definition: nnet-nnet.cc:182