nnet-am-average.cc
Go to the documentation of this file.
1 // nnet2bin/nnet-am-average.cc
2 
3 // Copyright 2012 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABILITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include <algorithm>
21 
22 #include "base/kaldi-common.h"
23 #include "util/common-utils.h"
24 #include "hmm/transition-model.h"
25 #include "nnet2/combine-nnet-a.h"
26 #include "nnet2/am-nnet.h"
27 
28 namespace kaldi {
29 
30 void GetWeights(const std::string &weights_str,
31  int32 num_inputs,
32  std::vector<BaseFloat> *weights) {
33  KALDI_ASSERT(num_inputs >= 1);
34  if (!weights_str.empty()) {
35  SplitStringToFloats(weights_str, ":", true, weights);
36  if (weights->size() != num_inputs) {
37  KALDI_ERR << "--weights option must be a colon-separated list "
38  << "with " << num_inputs << " elements, got: "
39  << weights_str;
40  }
41  } else {
42  for (int32 i = 0; i < num_inputs; i++)
43  weights->push_back(1.0 / num_inputs);
44  }
45  // normalize the weights to sum to one.
46  float weight_sum = 0.0;
47  for (int32 i = 0; i < num_inputs; i++)
48  weight_sum += (*weights)[i];
49  for (int32 i = 0; i < num_inputs; i++)
50  (*weights)[i] = (*weights)[i] / weight_sum;
51  if (fabs(weight_sum - 1.0) > 0.01) {
52  KALDI_WARN << "Normalizing weights to sum to one, sum was " << weight_sum;
53  }
54 }
55 
56 
57 
58 std::vector<bool> GetSkipLayers(const std::string &skip_layers_str,
59  const int32 first_layer_idx,
60  const int32 last_layer_idx) {
61 
62  std::vector<bool> skip_layers(last_layer_idx, false);
63 
64  if (skip_layers_str.empty()) {
65  return skip_layers;
66  }
67 
68  std::vector<int> layer_indices;
69  bool ret = SplitStringToIntegers(skip_layers_str, ":", true, &layer_indices);
70  if (!ret) {
71  KALDI_ERR << "Cannot parse the skip layers specifier. It should be"
72  << "colon-separated list of integers";
73  }
74 
75  int min_elem = std::numeric_limits<int>().max(),
76  max_elem = std::numeric_limits<int>().min();
77 
78  std::vector<int>::iterator it;
79  for ( it = layer_indices.begin(); it != layer_indices.end(); ++it ) {
80  if ( *it < 0 )
81  *it = last_layer_idx + *it; // convert the negative indices to
82  // correct indices -- -1 would be the
83  // last one, -2 the one before the last
84  // and so on.
85  if (*it > max_elem)
86  max_elem = *it;
87 
88  if (*it < min_elem)
89  min_elem = *it;
90  }
91 
92  if (max_elem >= last_layer_idx) {
93  KALDI_ERR << "--skip-layers option has to be a colon-separated list"
94  << "of indices which are supposed to be skipped.\n"
95  << "Maximum expected index: " << last_layer_idx
96  << " got: " << max_elem ;
97  }
98  if (min_elem < first_layer_idx) {
99  KALDI_ERR << "--skip-layers option has to be a colon-separated list"
100  << "of indices which are supposed to be skipped.\n"
101  << "Minimum expected index: " << first_layer_idx
102  << " got: " << min_elem ;
103  }
104 
105  for ( it = layer_indices.begin(); it != layer_indices.end(); ++it ) {
106  skip_layers[*it] = true;
107  }
108  return skip_layers;
109 }
110 
111 }
112 int main(int argc, char *argv[]) {
113  try {
114  using namespace kaldi;
115  using namespace kaldi::nnet2;
116  using std::string;
117  typedef kaldi::int32 int32;
118  typedef kaldi::int64 int64;
119 
120  const char *usage =
121  "This program averages (or sums, if --sum=true) the parameters over a\n"
122  "number of neural nets. If you supply the option --skip-last-layer=true,\n"
123  "the parameters of the last updatable layer are copied from <model1> instead\n"
124  "of being averaged (useful in multi-language scenarios).\n"
125  "The --weights option can be used to weight each model differently.\n"
126  "\n"
127  "Usage: nnet-am-average [options] <model1> <model2> ... <modelN> <model-out>\n"
128  "\n"
129  "e.g.:\n"
130  " nnet-am-average 1.1.nnet 1.2.nnet 1.3.nnet 2.nnet\n";
131 
132  bool binary_write = true;
133  bool sum = false;
134 
135  ParseOptions po(usage);
136  po.Register("sum", &sum, "If true, sums instead of averages.");
137  po.Register("binary", &binary_write, "Write output in binary mode");
138  string weights_str;
139  bool skip_last_layer = false;
140  string skip_layers_str;
141  po.Register("weights", &weights_str, "Colon-separated list of weights, one "
142  "for each input model. These will be normalized to sum to one.");
143  po.Register("skip-last-layer", &skip_last_layer, "If true, averaging of "
144  "the last updatable layer is skipped (result comes from model1)");
145  po.Register("skip-layers", &skip_layers_str, "Colon-separated list of "
146  "indices of the layers that should be skipped during averaging."
147  "Be careful: this parameter uses an absolute indexing of "
148  "layers, i.e. iterates over all components, not over updatable "
149  "ones only.");
150 
151  po.Read(argc, argv);
152 
153  if (po.NumArgs() < 2) {
154  po.PrintUsage();
155  exit(1);
156  }
157 
158  std::string
159  nnet1_rxfilename = po.GetArg(1),
160  nnet_wxfilename = po.GetArg(po.NumArgs());
161 
162  TransitionModel trans_model1;
163  AmNnet am_nnet1;
164  {
165  bool binary_read;
166  Input ki(nnet1_rxfilename, &binary_read);
167  trans_model1.Read(ki.Stream(), binary_read);
168  am_nnet1.Read(ki.Stream(), binary_read);
169  }
170 
171  int32 num_inputs = po.NumArgs() - 1;
172 
173  std::vector<BaseFloat> model_weights;
174  GetWeights(weights_str, num_inputs, &model_weights);
175 
176  int32 c_begin = 0,
177  c_end = (skip_last_layer ?
178  am_nnet1.GetNnet().LastUpdatableComponent() :
179  am_nnet1.GetNnet().NumComponents());
180  KALDI_ASSERT(c_end != -1 && "Network has no updatable components.");
181 
182  int32 last_layer_idx = am_nnet1.GetNnet().NumComponents();
183  std::vector<bool> skip_layers = GetSkipLayers(skip_layers_str,
184  0,
185  last_layer_idx);
186 
187  // scale the components - except the last layer, if skip_last_layer == true.
188  for (int32 c = c_begin; c < c_end; c++) {
189  if (skip_layers[c]) {
190  KALDI_VLOG(2) << "Not averaging layer " << c << " (as requested)";
191  continue;
192  }
193  bool updated = false;
194  UpdatableComponent *uc =
195  dynamic_cast<UpdatableComponent*>(&(am_nnet1.GetNnet().GetComponent(c)));
196  if (uc != NULL) {
197  KALDI_VLOG(2) << "Averaging layer " << c << " (UpdatableComponent)";
198  uc->Scale(model_weights[0]);
199  updated = true;
200  }
201  NonlinearComponent *nc =
202  dynamic_cast<NonlinearComponent*>(&(am_nnet1.GetNnet().GetComponent(c)));
203  if (nc != NULL) {
204  KALDI_VLOG(2) << "Averaging layer " << c << " (NonlinearComponent)";
205  nc->Scale(model_weights[0]);
206  updated = true;
207  }
208  if (! updated) {
209  KALDI_VLOG(2) << "Not averaging layer " << c
210  << " (unscalable component)";
211  }
212  }
213 
214  for (int32 i = 2; i <= num_inputs; i++) {
215  bool binary_read;
216  Input ki(po.GetArg(i), &binary_read);
217  TransitionModel trans_model;
218  trans_model.Read(ki.Stream(), binary_read);
219  AmNnet am_nnet;
220  am_nnet.Read(ki.Stream(), binary_read);
221 
222  for (int32 c = c_begin; c < c_end; c++) {
223  if (skip_layers[c]) continue;
224 
225  UpdatableComponent *uc_average =
226  dynamic_cast<UpdatableComponent*>(&(am_nnet1.GetNnet().GetComponent(c)));
227  const UpdatableComponent *uc_this =
228  dynamic_cast<const UpdatableComponent*>(&(am_nnet.GetNnet().GetComponent(c)));
229  if (uc_average != NULL) {
230  KALDI_ASSERT(uc_this != NULL &&
231  "Networks must have the same structure.");
232  uc_average->Add(model_weights[i-1], *uc_this);
233  }
234 
235  NonlinearComponent *nc_average =
236  dynamic_cast<NonlinearComponent*>(&(am_nnet1.GetNnet().GetComponent(c)));
237  const NonlinearComponent *nc_this =
238  dynamic_cast<const NonlinearComponent*>(&(am_nnet.GetNnet().GetComponent(c)));
239  if (nc_average != NULL) {
240  KALDI_ASSERT(nc_this != NULL &&
241  "Networks must have the same structure.");
242  nc_average->Add(model_weights[i-1], *nc_this);
243  }
244  }
245  }
246 
247  {
248  Output ko(nnet_wxfilename, binary_write);
249  trans_model1.Write(ko.Stream(), binary_write);
250  am_nnet1.Write(ko.Stream(), binary_write);
251  }
252 
253  KALDI_LOG << "Averaged parameters of " << num_inputs
254  << " neural nets, and wrote to " << nnet_wxfilename;
255  return 0; // it will throw an exception if there are any problems.
256  } catch(const std::exception &e) {
257  std::cerr << e.what() << '\n';
258  return -1;
259  }
260 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
const Component & GetComponent(int32 c) const
Definition: nnet-nnet.cc:141
This kind of Component is a base-class for things like sigmoid and softmax.
void GetWeights(const std::string &weights_str, int32 num_inputs, std::vector< BaseFloat > *weights)
bool SplitStringToFloats(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< F > *out)
Definition: text-utils.cc:30
virtual void Scale(BaseFloat scale)=0
This new virtual function scales the parameters by this amount.
bool SplitStringToIntegers(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< I > *out)
Split a string (e.g.
Definition: text-utils.h:68
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void Add(BaseFloat alpha, const NonlinearComponent &other)
void Read(std::istream &is, bool binary)
Definition: am-nnet.cc:39
kaldi::int32 int32
virtual void Add(BaseFloat alpha, const UpdatableComponent &other)=0
This new virtual function adds the parameters of another updatable component, times some constant...
void Register(const std::string &name, bool *ptr, const std::string &doc)
int32 NumComponents() const
Returns number of components– think of this as similar to # of layers, but e.g.
Definition: nnet-nnet.h:69
std::istream & Stream()
Definition: kaldi-io.cc:826
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
std::ostream & Stream()
Definition: kaldi-io.cc:701
void Read(std::istream &is, bool binary)
int main(int argc, char *argv[])
void Write(std::ostream &os, bool binary) const
Definition: am-nnet.cc:31
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
#define KALDI_WARN
Definition: kaldi-error.h:150
std::vector< bool > GetSkipLayers(const std::string &skip_layers_str, const int32 first_layer_idx, const int32 last_layer_idx)
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void Write(std::ostream &os, bool binary) const
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
#define KALDI_LOG
Definition: kaldi-error.h:153
const Nnet & GetNnet() const
Definition: am-nnet.h:61
Class UpdatableComponent is a Component which has trainable parameters and contains some global param...
int32 LastUpdatableComponent() const
Returns the index of the highest-numbered component which is updatable, or -1 if none are updatable...
Definition: nnet-nnet.cc:837