void GetWeights(const std::string &weights_str, int32 num_inputs,
                std::vector<BaseFloat> *weights) {
  KALDI_ASSERT(num_inputs >= 1);
  if (!weights_str.empty()) {
    SplitStringToFloats(weights_str, ":", true, weights);
    if (weights->size() != num_inputs) {
      KALDI_ERR << "--weights option must be a colon-separated list "
                << "with " << num_inputs << " elements, got: " << weights_str;
    }
  } else {
    for (int32 i = 0; i < num_inputs; i++)
      weights->push_back(1.0 / num_inputs);
  }
  // Normalize the weights so that they sum to one.
  float weight_sum = 0.0;
  for (int32 i = 0; i < num_inputs; i++)
    weight_sum += (*weights)[i];
  for (int32 i = 0; i < num_inputs; i++)
    (*weights)[i] = (*weights)[i] / weight_sum;
  if (fabs(weight_sum - 1.0) > 0.01) {
    KALDI_WARN << "Normalizing weights to sum to one, sum was " << weight_sum;
  }
}
std::vector<bool> GetSkipLayers(const std::string &skip_layers_str,
                                const int32 first_layer_idx,
                                const int32 last_layer_idx) {
  std::vector<bool> skip_layers(last_layer_idx, false);

  if (skip_layers_str.empty()) {
    return skip_layers;
  }

  std::vector<int> layer_indices;
  if (!SplitStringToIntegers(skip_layers_str, ":", true, &layer_indices)) {
    KALDI_ERR << "Cannot parse the skip-layers specifier. It should be a "
              << "colon-separated list of integers";
  }

  int min_elem = std::numeric_limits<int>::max(),
      max_elem = std::numeric_limits<int>::min();

  std::vector<int>::iterator it;
  for (it = layer_indices.begin(); it != layer_indices.end(); ++it) {
    if (*it < 0)  // Negative indices count backwards from the last layer.
      *it = last_layer_idx + *it;
    if (*it < min_elem) min_elem = *it;
    if (*it > max_elem) max_elem = *it;
  }

  if (max_elem >= last_layer_idx) {
    KALDI_ERR << "--skip-layers option has to be a colon-separated list "
              << "of indices which are supposed to be skipped.\n"
              << "Maximum expected index: " << last_layer_idx
              << ", got: " << max_elem;
  }
  if (min_elem < first_layer_idx) {
    KALDI_ERR << "--skip-layers option has to be a colon-separated list "
              << "of indices which are supposed to be skipped.\n"
              << "Minimum expected index: " << first_layer_idx
              << ", got: " << min_elem;
  }

  for (it = layer_indices.begin(); it != layer_indices.end(); ++it)
    skip_layers[*it] = true;

  return skip_layers;
}
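
// Example of the index handling above (illustrative values): with
// last_layer_idx == 8, --skip-layers=0:2 marks components 0 and 2, while
// --skip-layers=-1 is first converted to 8 + (-1) = 7, i.e. the last
// component.  Indices of 8 or more, or below first_layer_idx, are rejected.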
int main(int argc, char *argv[]) {
  try {
    using namespace kaldi;
    using namespace kaldi::nnet2;
    typedef kaldi::int32 int32;
    typedef kaldi::int64 int64;

    const char *usage =
        "This program averages (or sums, if --sum=true) the parameters over a\n"
        "number of neural nets.  If you supply the option --skip-last-layer=true,\n"
        "the parameters of the last updatable layer are copied from <model1> instead\n"
        "of being averaged (useful in multi-language scenarios).\n"
        "The --weights option can be used to weight each model differently.\n"
        "\n"
        "Usage:  nnet-am-average [options] <model1> <model2> ... <modelN> <model-out>\n"
        "\n"
        "e.g.:\n"
        " nnet-am-average 1.1.nnet 1.2.nnet 1.3.nnet 2.nnet\n";
    bool binary_write = true;
    bool sum = false;

    ParseOptions po(usage);
    po.Register("sum", &sum, "If true, sums instead of averages.");
    po.Register("binary", &binary_write, "Write output in binary mode");

    std::string weights_str;
    bool skip_last_layer = false;
    std::string skip_layers_str;
    po.Register("weights", &weights_str, "Colon-separated list of weights, one "
                "for each input model. These will be normalized to sum to one.");
    po.Register("skip-last-layer", &skip_last_layer, "If true, averaging of "
                "the last updatable layer is skipped (the result comes from model1).");
    po.Register("skip-layers", &skip_layers_str, "Colon-separated list of "
                "indices of the layers that should be skipped during averaging. "
                "Be careful: this parameter uses absolute indexing of layers, "
                "i.e. it iterates over all components, not just the updatable ones.");

    po.Read(argc, argv);

    if (po.NumArgs() < 2) {
      po.PrintUsage();
      exit(1);
    }

    std::string
        nnet1_rxfilename = po.GetArg(1),
        nnet_wxfilename = po.GetArg(po.NumArgs());
    TransitionModel trans_model1;
    AmNnet am_nnet1;
    {
      bool binary_read;
      Input ki(nnet1_rxfilename, &binary_read);
      trans_model1.Read(ki.Stream(), binary_read);
      am_nnet1.Read(ki.Stream(), binary_read);
    }

    int32 num_inputs = po.NumArgs() - 1;

    std::vector<BaseFloat> model_weights;
    GetWeights(weights_str, num_inputs, &model_weights);
    if (sum) {
      // With --sum=true the models are summed, not averaged, so use unit
      // weights regardless of the normalization done in GetWeights().
      for (int32 i = 0; i < num_inputs; i++)
        model_weights[i] = 1.0;
    }

    int32 c_begin = 0,
        c_end = (skip_last_layer ?
                 am_nnet1.GetNnet().LastUpdatableComponent() :
                 am_nnet1.GetNnet().NumComponents());

    KALDI_ASSERT(c_end != -1 && "Network has no updatable components.");

    std::vector<bool> skip_layers =
        GetSkipLayers(skip_layers_str, 0, am_nnet1.GetNnet().NumComponents());
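
    // Components in the range [c_begin, c_end) take part in the averaging.
    // With --skip-last-layer=true, c_end stops at the last updatable
    // component, so that component keeps the parameters it had in <model1>.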
    // First, scale the components of model1 by its weight; the remaining
    // models are added on top of this below.
    for (int32 c = c_begin; c < c_end; c++) {
      if (skip_layers[c]) {
        KALDI_VLOG(2) << "Not averaging layer " << c << " (as requested)";
        continue;
      }

      bool updated = false;
      UpdatableComponent *uc =
          dynamic_cast<UpdatableComponent*>(&(am_nnet1.GetNnet().GetComponent(c)));
      if (uc != NULL) {
        KALDI_VLOG(2) << "Averaging layer " << c << " (UpdatableComponent)";
        uc->Scale(model_weights[0]);
        updated = true;
      }
      NonlinearComponent *nc =
          dynamic_cast<NonlinearComponent*>(&(am_nnet1.GetNnet().GetComponent(c)));
      if (nc != NULL) {
        KALDI_VLOG(2) << "Averaging layer " << c << " (NonlinearComponent)";
        nc->Scale(model_weights[0]);
        updated = true;
      }
      if (!updated) {
        KALDI_VLOG(2) << "Not averaging layer " << c
                      << " (unscalable component)";
      }
    }
    for (int32 i = 2; i <= num_inputs; i++) {
      bool binary_read;
      Input ki(po.GetArg(i), &binary_read);
      TransitionModel trans_model;
      trans_model.Read(ki.Stream(), binary_read);
      AmNnet am_nnet;
      am_nnet.Read(ki.Stream(), binary_read);

      for (int32 c = c_begin; c < c_end; c++) {
        if (skip_layers[c]) continue;

        UpdatableComponent *uc_average =
            dynamic_cast<UpdatableComponent*>(&(am_nnet1.GetNnet().GetComponent(c)));
        const UpdatableComponent *uc_this =
            dynamic_cast<const UpdatableComponent*>(&(am_nnet.GetNnet().GetComponent(c)));
        if (uc_average != NULL) {
          KALDI_ASSERT(uc_this != NULL &&
                       "Networks must have the same structure.");
          uc_average->Add(model_weights[i-1], *uc_this);
        }

        NonlinearComponent *nc_average =
            dynamic_cast<NonlinearComponent*>(&(am_nnet1.GetNnet().GetComponent(c)));
        const NonlinearComponent *nc_this =
            dynamic_cast<const NonlinearComponent*>(&(am_nnet.GetNnet().GetComponent(c)));
        if (nc_average != NULL) {
          KALDI_ASSERT(nc_this != NULL &&
                       "Networks must have the same structure.");
          nc_average->Add(model_weights[i-1], *nc_this);
        }
      }
    }
    {
      Output ko(nnet_wxfilename, binary_write);
      trans_model1.Write(ko.Stream(), binary_write);
      am_nnet1.Write(ko.Stream(), binary_write);
    }

    KALDI_LOG << "Averaged parameters of " << num_inputs
              << " neural nets, and wrote to " << nnet_wxfilename;
    return 0;
  } catch(const std::exception &e) {
    std::cerr << e.what() << '\n';
    return -1;
  }
}