nnet-normalize-stddev.cc File Reference
Include dependency graph for nnet-normalize-stddev.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 28 of file nnet-normalize-stddev.cc.

References ParseOptions::GetArg(), Nnet::GetComponent(), AmNnet::GetNnet(), UpdatableComponent::GetParameterDim(), KALDI_ASSERT, KALDI_LOG, ParseOptions::NumArgs(), Nnet::NumComponents(), ParseOptions::PrintUsage(), AmNnet::Read(), ParseOptions::Read(), TransitionModel::Read(), ParseOptions::Register(), UpdatableComponent::Scale(), Output::Stream(), Input::Stream(), UpdatableComponent::Vectorize(), kaldi::VecVec(), AmNnet::Write(), and TransitionModel::Write().

28  {
29  try {
30  using namespace kaldi;
31  using namespace kaldi::nnet2;
32  typedef kaldi::int32 int32;
33  typedef kaldi::int64 int64;
34 
35  const char *usage =
36  "This program first identifies any affine or block affine layers that\n"
37  "are followed by pnorm and then renormalize layers. Then it rescales\n"
38  "those layers such that the parameter stddev is 1.0 after scaling\n"
39  "(the target stddev is configurable by the --stddev option).\n"
40  "If you supply the option --stddev-from=<model-filename>, it rescales\n"
41  "those layers to match the standard deviations of corresponding layers\n"
42  "in the specified model.\n"
43  "\n"
44  "Usage: nnet-normalize-stddev [options] <model-in> <model-out>\n"
45  " e.g.: nnet-normalize-stddev final.mdl final.mdl\n";
46 
47  bool binary_write = true;
48  BaseFloat stddev = 1.0;
49  std::string reference_model_filename;
50 
51  ParseOptions po(usage);
52  po.Register("binary", &binary_write, "Write output in binary mode");
53  po.Register("stddev-from", &reference_model_filename, "Reference model");
54  po.Register("stddev", &stddev, "Target standard deviation that we normalize "
55  "to (note: is overridden by --stddev-from option, if supplied)");
56 
57  po.Read(argc, argv);
58 
59  if (po.NumArgs() != 2) {
60  po.PrintUsage();
61  exit(1);
62  }
63 
64  std::string nnet_rxfilename = po.GetArg(1),
65  normalized_nnet_rxfilename = po.GetArg(2);
66 
67  TransitionModel trans_model;
68  AmNnet am_nnet;
69  {
70  bool binary_read;
71  Input ki(nnet_rxfilename, &binary_read);
72  trans_model.Read(ki.Stream(), binary_read);
73  am_nnet.Read(ki.Stream(), binary_read);
74  }
75 
76  int32 ret = 0;
77 
78  // Works out the layers that we would like to normalize: any affine or block
79  // affine layers that are followed by pnorm and then renormalize layers.
80  std::vector<int32> identified_components;
81  for (int32 c = 0; c < am_nnet.GetNnet().NumComponents() - 2; c++) {
82  // Checks if the current layer is an affine layer or block affine layer.
83  // Also includes PreconditionedAffineComponent and
84  // PreconditionedAffineComponentOnline, since they are child classes of
85  // AffineComponent.
86  kaldi::nnet2::Component *component = &(am_nnet.GetNnet().GetComponent(c));
87  AffineComponent *ac = dynamic_cast<AffineComponent*>(component);
89  dynamic_cast<BlockAffineComponent*>(component);
90  if (ac == NULL && bac == NULL)
91  continue;
92 
93  // Checks if the next layer is a pnorm layer.
94  component = &(am_nnet.GetNnet().GetComponent(c + 1));
95  PnormComponent *pc = dynamic_cast<PnormComponent*>(component);
96  if (pc == NULL)
97  continue;
98 
99  // Checks if the layer after the pnorm layer is a NormalizeComponent
100  // or a PowerComponent followed by a NormalizeComponent
101  component = &(am_nnet.GetNnet().GetComponent(c + 2));
102  NormalizeComponent *nc = dynamic_cast<NormalizeComponent*>(component);
103  PowerComponent *pwc = dynamic_cast<PowerComponent*>(component);
104  if (nc == NULL && pwc == NULL)
105  continue;
106  if (pwc != NULL) { // verify it's PowerComponent followed by
107  // NormalizeComponent.
108  if (c + 3 >= am_nnet.GetNnet().NumComponents())
109  continue;
110  component = &(am_nnet.GetNnet().GetComponent(c + 3));
111  nc = dynamic_cast<NormalizeComponent*>(component);
112  if (nc == NULL)
113  continue;
114  }
115  // This is the layer that we would like to normalize.
116  identified_components.push_back(c);
117  }
118 
119  AmNnet am_nnet_ref;
120  if (!reference_model_filename.empty()) {
121  bool binary_read;
122  Input ki(reference_model_filename, &binary_read);
123  trans_model.Read(ki.Stream(), binary_read);
124  am_nnet_ref.Read(ki.Stream(), binary_read);
125  KALDI_ASSERT(am_nnet_ref.GetNnet().NumComponents() == am_nnet.GetNnet().NumComponents());
126  }
127 
128  BaseFloat ref_stddev = 0.0;
129 
130  // Normalizes the identified layers.
131  for (int32 c = 0; c < identified_components.size(); c++) {
132  ref_stddev = stddev;
133  if (!reference_model_filename.empty()) {
134  kaldi::nnet2::Component *component =
135  &(am_nnet_ref.GetNnet().GetComponent(identified_components[c]));
136  UpdatableComponent *uc = dynamic_cast<UpdatableComponent*>(component);
137  KALDI_ASSERT(uc != NULL);
138  Vector<BaseFloat> params(uc->GetParameterDim());
139  uc->Vectorize(&params);
140  BaseFloat params_average = params.Sum()
141  / static_cast<BaseFloat>(params.Dim());
142  params.Add(-1.0 * params_average);
143  ref_stddev = sqrt(VecVec(params, params)
144  / static_cast<BaseFloat>(params.Dim()));
145  }
146 
147  kaldi::nnet2::Component *component =
148  &(am_nnet.GetNnet().GetComponent(identified_components[c]));
149  UpdatableComponent *uc = dynamic_cast<UpdatableComponent*>(component);
150  KALDI_ASSERT(uc != NULL);
151  Vector<BaseFloat> params(uc->GetParameterDim());
152  uc->Vectorize(&params);
153  BaseFloat params_average = params.Sum()
154  / static_cast<BaseFloat>(params.Dim());
155  params.Add(-1.0 * params_average);
156  BaseFloat params_stddev = sqrt(VecVec(params, params)
157  / static_cast<BaseFloat>(params.Dim()));
158  if (params_stddev > 0.0) {
159  uc->Scale(ref_stddev / params_stddev);
160  KALDI_LOG << "Normalized component " << identified_components[c];
161  }
162  }
163 
164  // Writes the normalized model.
165  Output ko(normalized_nnet_rxfilename, binary_write);
166  trans_model.Write(ko.Stream(), binary_write);
167  am_nnet.Write(ko.Stream(), binary_write);
168 
169  return ret;
170  } catch(const std::exception &e) {
171  std::cerr << e.what() << '\n';
172  return -1;
173  }
174 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
const Component & GetComponent(int32 c) const
Definition: nnet-nnet.cc:141
virtual void Scale(BaseFloat scale)=0
This new virtual function scales the parameters by this amount.
Abstract class, basic element of the network, it is a box with defined inputs, outputs, and transformation functions interface.
void Read(std::istream &is, bool binary)
Definition: am-nnet.cc:39
kaldi::int32 int32
virtual void Vectorize(VectorBase< BaseFloat > *params) const
Turns the parameters into vector form.
int32 NumComponents() const
Returns the number of components — think of this as similar to the number of layers, but e.g. ...
Definition: nnet-nnet.h:69
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
Takes the absolute values of an input vector to a power.
void Read(std::istream &is, bool binary)
void Write(std::ostream &os, bool binary) const
Definition: am-nnet.cc:31
virtual int32 GetParameterDim() const
The following new virtual function returns the total dimension of the parameters in this class...
void Write(std::ostream &os, bool binary) const
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
#define KALDI_LOG
Definition: kaldi-error.h:153
Real VecVec(const VectorBase< Real > &a, const VectorBase< Real > &b)
Returns the dot product of a and b.
Definition: kaldi-vector.cc:37
const Nnet & GetNnet() const
Definition: am-nnet.h:61
Class UpdatableComponent is a Component which has trainable parameters and contains some global param...