nnet-am-average.cc File Reference
#include <algorithm>
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "hmm/transition-model.h"
#include "nnet2/combine-nnet-a.h"
#include "nnet2/am-nnet.h"
Include dependency graph for nnet-am-average.cc:


Namespaces

 kaldi
 The main namespace of the Kaldi speech recognition toolkit.
 

Functions

void GetWeights (const std::string &weights_str, int32 num_inputs, std::vector< BaseFloat > *weights)
 
std::vector< bool > GetSkipLayers (const std::string &skip_layers_str, const int32 first_layer_idx, const int32 last_layer_idx)
 
int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main (int argc, char *argv[])

Definition at line 112 of file nnet-am-average.cc.

References UpdatableComponent::Add(), NonlinearComponent::Add(), ParseOptions::GetArg(), Nnet::GetComponent(), AmNnet::GetNnet(), kaldi::GetSkipLayers(), kaldi::GetWeights(), rnnlm::i, KALDI_ASSERT, KALDI_LOG, KALDI_VLOG, Nnet::LastUpdatableComponent(), ParseOptions::NumArgs(), Nnet::NumComponents(), ParseOptions::PrintUsage(), AmNnet::Read(), ParseOptions::Read(), TransitionModel::Read(), ParseOptions::Register(), UpdatableComponent::Scale(), NonlinearComponent::Scale(), Output::Stream(), Input::Stream(), AmNnet::Write(), and TransitionModel::Write().
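
The kaldi::GetWeights() and kaldi::GetSkipLayers() helpers referenced above are defined elsewhere in nnet-am-average.cc and their bodies are not reproduced on this page. As a rough, hypothetical sketch of what the --weights handling amounts to (colon-separated values, one per input model, normalized to sum to one, and presumably falling back to uniform weights when the option is empty), the standalone function below illustrates the idea; ParseColonSeparatedWeights is an invented name, not the Kaldi API, and the real GetWeights() may differ in its defaults and error handling.

// Hypothetical stand-in for kaldi::GetWeights(); illustrative only.
#include <cassert>
#include <sstream>
#include <string>
#include <vector>

std::vector<float> ParseColonSeparatedWeights(const std::string &weights_str,
                                              int num_inputs) {
  std::vector<float> weights;
  if (weights_str.empty()) {
    // Assumption: no --weights given means uniform weighting of all models.
    weights.assign(num_inputs, 1.0f / num_inputs);
    return weights;
  }
  std::stringstream ss(weights_str);
  std::string token;
  while (std::getline(ss, token, ':'))  // split e.g. "0.5:0.3:0.2" on ':'
    weights.push_back(std::stof(token));
  assert(static_cast<int>(weights.size()) == num_inputs &&
         "Need exactly one weight per input model");
  float sum = 0.0f;
  for (float w : weights) sum += w;
  for (float &w : weights) w /= sum;    // normalize so the weights sum to one
  return weights;
}

GetSkipLayers() appears, from its use in the listing below, to return a per-component boolean mask over the index range [0, NumComponents()), true for the indices given in --skip-layers.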

112  {
113    try {
114      using namespace kaldi;
115      using namespace kaldi::nnet2;
116      using std::string;
117      typedef kaldi::int32 int32;
118      typedef kaldi::int64 int64;
119  
120      const char *usage =
121          "This program averages (or sums, if --sum=true) the parameters over a\n"
122          "number of neural nets. If you supply the option --skip-last-layer=true,\n"
123          "the parameters of the last updatable layer are copied from <model1> instead\n"
124          "of being averaged (useful in multi-language scenarios).\n"
125          "The --weights option can be used to weight each model differently.\n"
126          "\n"
127          "Usage: nnet-am-average [options] <model1> <model2> ... <modelN> <model-out>\n"
128          "\n"
129          "e.g.:\n"
130          " nnet-am-average 1.1.nnet 1.2.nnet 1.3.nnet 2.nnet\n";
131  
132      bool binary_write = true;
133      bool sum = false;
134  
135      ParseOptions po(usage);
136      po.Register("sum", &sum, "If true, sums instead of averages.");
137      po.Register("binary", &binary_write, "Write output in binary mode");
138      string weights_str;
139      bool skip_last_layer = false;
140      string skip_layers_str;
141      po.Register("weights", &weights_str, "Colon-separated list of weights, one "
142                  "for each input model. These will be normalized to sum to one.");
143      po.Register("skip-last-layer", &skip_last_layer, "If true, averaging of "
144                  "the last updatable layer is skipped (result comes from model1)");
145      po.Register("skip-layers", &skip_layers_str, "Colon-separated list of "
146                  "indices of the layers that should be skipped during averaging."
147                  "Be careful: this parameter uses an absolute indexing of "
148                  "layers, i.e. iterates over all components, not over updatable "
149                  "ones only.");
150  
151      po.Read(argc, argv);
152  
153      if (po.NumArgs() < 2) {
154        po.PrintUsage();
155        exit(1);
156      }
157  
158      std::string
159          nnet1_rxfilename = po.GetArg(1),
160          nnet_wxfilename = po.GetArg(po.NumArgs());
161  
162      TransitionModel trans_model1;
163      AmNnet am_nnet1;
164      {
165        bool binary_read;
166        Input ki(nnet1_rxfilename, &binary_read);
167        trans_model1.Read(ki.Stream(), binary_read);
168        am_nnet1.Read(ki.Stream(), binary_read);
169      }
170  
171      int32 num_inputs = po.NumArgs() - 1;
172  
173      std::vector<BaseFloat> model_weights;
174      GetWeights(weights_str, num_inputs, &model_weights);
175  
176      int32 c_begin = 0,
177          c_end = (skip_last_layer ?
178                   am_nnet1.GetNnet().LastUpdatableComponent() :
179                   am_nnet1.GetNnet().NumComponents());
180      KALDI_ASSERT(c_end != -1 && "Network has no updatable components.");
181  
182      int32 last_layer_idx = am_nnet1.GetNnet().NumComponents();
183      std::vector<bool> skip_layers = GetSkipLayers(skip_layers_str,
184                                                    0,
185                                                    last_layer_idx);
186  
187      // scale the components - except the last layer, if skip_last_layer == true.
188      for (int32 c = c_begin; c < c_end; c++) {
189        if (skip_layers[c]) {
190          KALDI_VLOG(2) << "Not averaging layer " << c << " (as requested)";
191          continue;
192        }
193        bool updated = false;
194        UpdatableComponent *uc =
195            dynamic_cast<UpdatableComponent*>(&(am_nnet1.GetNnet().GetComponent(c)));
196        if (uc != NULL) {
197          KALDI_VLOG(2) << "Averaging layer " << c << " (UpdatableComponent)";
198          uc->Scale(model_weights[0]);
199          updated = true;
200        }
201        NonlinearComponent *nc =
202            dynamic_cast<NonlinearComponent*>(&(am_nnet1.GetNnet().GetComponent(c)));
203        if (nc != NULL) {
204          KALDI_VLOG(2) << "Averaging layer " << c << " (NonlinearComponent)";
205          nc->Scale(model_weights[0]);
206          updated = true;
207        }
208        if (! updated) {
209          KALDI_VLOG(2) << "Not averaging layer " << c
210                        << " (unscalable component)";
211        }
212      }
213  
214      for (int32 i = 2; i <= num_inputs; i++) {
215        bool binary_read;
216        Input ki(po.GetArg(i), &binary_read);
217        TransitionModel trans_model;
218        trans_model.Read(ki.Stream(), binary_read);
219        AmNnet am_nnet;
220        am_nnet.Read(ki.Stream(), binary_read);
221  
222        for (int32 c = c_begin; c < c_end; c++) {
223          if (skip_layers[c]) continue;
224  
225          UpdatableComponent *uc_average =
226              dynamic_cast<UpdatableComponent*>(&(am_nnet1.GetNnet().GetComponent(c)));
227          const UpdatableComponent *uc_this =
228              dynamic_cast<const UpdatableComponent*>(&(am_nnet.GetNnet().GetComponent(c)));
229          if (uc_average != NULL) {
230            KALDI_ASSERT(uc_this != NULL &&
231                         "Networks must have the same structure.");
232            uc_average->Add(model_weights[i-1], *uc_this);
233          }
234  
235          NonlinearComponent *nc_average =
236              dynamic_cast<NonlinearComponent*>(&(am_nnet1.GetNnet().GetComponent(c)));
237          const NonlinearComponent *nc_this =
238              dynamic_cast<const NonlinearComponent*>(&(am_nnet.GetNnet().GetComponent(c)));
239          if (nc_average != NULL) {
240            KALDI_ASSERT(nc_this != NULL &&
241                         "Networks must have the same structure.");
242            nc_average->Add(model_weights[i-1], *nc_this);
243          }
244        }
245      }
246  
247      {
248        Output ko(nnet_wxfilename, binary_write);
249        trans_model1.Write(ko.Stream(), binary_write);
250        am_nnet1.Write(ko.Stream(), binary_write);
251      }
252  
253      KALDI_LOG << "Averaged parameters of " << num_inputs
254                << " neural nets, and wrote to " << nnet_wxfilename;
255      return 0; // it will throw an exception if there are any problems.
256    } catch(const std::exception &e) {
257      std::cerr << e.what() << '\n';
258      return -1;
259    }
260  }
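
The heart of the listing is a scale-then-accumulate pattern: each selected component of <model1> is first scaled by the first weight, and the matching component of every remaining model is then added in with its own weight, so am_nnet1 ends up holding the weighted combination (with --sum=true the same loop yields a sum rather than an average, presumably via the weights that GetWeights() returns). Below is a self-contained toy version of that pattern; std::vector<float> stands in for a component's parameters, and the Scale/Add free functions are illustrative stand-ins for UpdatableComponent::Scale() and ::Add(), not the Kaldi API.

// Toy illustration of the scale-then-accumulate averaging pattern above.
#include <cassert>
#include <cstdio>
#include <vector>

// Stand-in for UpdatableComponent::Scale(): multiply all parameters by 'scale'.
void Scale(float scale, std::vector<float> *params) {
  for (float &p : *params) p *= scale;
}

// Stand-in for UpdatableComponent::Add(): params += alpha * other.
void Add(float alpha, const std::vector<float> &other, std::vector<float> *params) {
  assert(params->size() == other.size() && "Models must have the same structure");
  for (size_t j = 0; j < other.size(); j++) (*params)[j] += alpha * other[j];
}

int main() {
  // Three "models", each reduced to one parameter vector, weighted 0.5:0.3:0.2.
  std::vector<std::vector<float> > models = {{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}};
  std::vector<float> weights = {0.5f, 0.3f, 0.2f};

  std::vector<float> average = models[0];
  Scale(weights[0], &average);                     // average  = w0 * model0
  for (size_t i = 1; i < models.size(); i++)
    Add(weights[i], models[i], &average);          // average += wi * modeli

  for (float p : average) std::printf("%g ", p);   // prints: 2.4 3.4
  std::printf("\n");
  return 0;
}

Note how, exactly as in the listing, the first model is only scaled and never added, which is why the outer loop over the remaining models starts at i = 2 there and at i = 1 here.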