114 using namespace kaldi;
118 typedef kaldi::int64 int64;
121 "This program averages (or sums, if --sum=true) the parameters over a\n" 122 "number of neural nets. If you supply the option --skip-last-layer=true,\n" 123 "the parameters of the last updatable layer are copied from <model1> instead\n" 124 "of being averaged (useful in multi-language scenarios).\n" 125 "The --weights option can be used to weight each model differently.\n" 127 "Usage: nnet-am-average [options] <model1> <model2> ... <modelN> <model-out>\n" 130 " nnet-am-average 1.1.nnet 1.2.nnet 1.3.nnet 2.nnet\n";
132 bool binary_write =
true;
136 po.Register(
"sum", &sum,
"If true, sums instead of averages.");
137 po.Register(
"binary", &binary_write,
"Write output in binary mode");
139 bool skip_last_layer =
false;
140 string skip_layers_str;
141 po.Register(
"weights", &weights_str,
"Colon-separated list of weights, one " 142 "for each input model. These will be normalized to sum to one.");
143 po.Register(
"skip-last-layer", &skip_last_layer,
"If true, averaging of " 144 "the last updatable layer is skipped (result comes from model1)");
145 po.Register(
"skip-layers", &skip_layers_str,
"Colon-separated list of " 146 "indices of the layers that should be skipped during averaging." 147 "Be careful: this parameter uses an absolute indexing of " 148 "layers, i.e. iterates over all components, not over updatable " 153 if (po.NumArgs() < 2) {
159 nnet1_rxfilename = po.GetArg(1),
160 nnet_wxfilename = po.GetArg(po.NumArgs());
166 Input ki(nnet1_rxfilename, &binary_read);
167 trans_model1.
Read(ki.Stream(), binary_read);
168 am_nnet1.
Read(ki.Stream(), binary_read);
171 int32 num_inputs = po.NumArgs() - 1;
173 std::vector<BaseFloat> model_weights;
174 GetWeights(weights_str, num_inputs, &model_weights);
177 c_end = (skip_last_layer ?
180 KALDI_ASSERT(c_end != -1 &&
"Network has no updatable components.");
183 std::vector<bool> skip_layers =
GetSkipLayers(skip_layers_str,
188 for (int32 c = c_begin; c < c_end; c++) {
189 if (skip_layers[c]) {
190 KALDI_VLOG(2) <<
"Not averaging layer " << c <<
" (as requested)";
193 bool updated =
false;
197 KALDI_VLOG(2) <<
"Averaging layer " << c <<
" (UpdatableComponent)";
198 uc->
Scale(model_weights[0]);
204 KALDI_VLOG(2) <<
"Averaging layer " << c <<
" (NonlinearComponent)";
205 nc->
Scale(model_weights[0]);
210 <<
" (unscalable component)";
214 for (int32
i = 2;
i <= num_inputs;
i++) {
216 Input ki(po.GetArg(
i), &binary_read);
218 trans_model.
Read(ki.Stream(), binary_read);
220 am_nnet.
Read(ki.Stream(), binary_read);
222 for (int32 c = c_begin; c < c_end; c++) {
223 if (skip_layers[c])
continue;
229 if (uc_average != NULL) {
231 "Networks must have the same structure.");
232 uc_average->
Add(model_weights[
i-1], *uc_this);
239 if (nc_average != NULL) {
241 "Networks must have the same structure.");
242 nc_average->
Add(model_weights[
i-1], *nc_this);
248 Output ko(nnet_wxfilename, binary_write);
249 trans_model1.
Write(ko.Stream(), binary_write);
250 am_nnet1.
Write(ko.Stream(), binary_write);
253 KALDI_LOG <<
"Averaged parameters of " << num_inputs
254 <<
" neural nets, and wrote to " << nnet_wxfilename;
256 }
catch(
const std::exception &e) {
257 std::cerr << e.what() <<
'\n';
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
const Component & GetComponent(int32 c) const
This kind of Component is a base-class for things like sigmoid and softmax.
void GetWeights(const std::string &weights_str, int32 num_inputs, std::vector< BaseFloat > *weights)
virtual void Scale(BaseFloat scale)=0
This new virtual function scales the parameters by this amount.
void Add(BaseFloat alpha, const NonlinearComponent &other)
void Read(std::istream &is, bool binary)
virtual void Add(BaseFloat alpha, const UpdatableComponent &other)=0
This new virtual function adds the parameters of another updatable component, times some constant...
int32 NumComponents() const
Returns number of components -- think of this as similar to # of layers, but e.g. ...
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
void Read(std::istream &is, bool binary)
void Write(std::ostream &os, bool binary) const
std::vector< bool > GetSkipLayers(const std::string &skip_layers_str, const int32 first_layer_idx, const int32 last_layer_idx)
void Write(std::ostream &os, bool binary) const
void Scale(BaseFloat scale)
#define KALDI_ASSERT(cond)
const Nnet & GetNnet() const
Class UpdatableComponent is a Component which has trainable parameters and contains some global param...
int32 LastUpdatableComponent() const
Returns the index of the highest-numbered component which is updatable, or -1 if none are updatable...