nnet-am-copy.cc File Reference
#include <typeinfo>
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "hmm/transition-model.h"
#include "nnet2/am-nnet.h"
#include "tree/context-dep.h"
Include dependency graph for nnet-am-copy.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 27 of file nnet-am-copy.cc.

References Nnet::Collapse(), kaldi::ConvertStringToReal(), Nnet::CopyStatsFrom(), ParseOptions::GetArg(), AmNnet::GetNnet(), KALDI_ERR, KALDI_LOG, ParseOptions::NumArgs(), Nnet::NumUpdatableComponents(), Nnet::OutputDim(), ParseOptions::PrintUsage(), AmNnet::Priors(), AmNnet::Read(), ParseOptions::Read(), TransitionModel::Read(), ParseOptions::Register(), Nnet::RemoveDropout(), Nnet::RemovePreconditioning(), Nnet::Resize(), Nnet::ScaleComponents(), Nnet::ScaleLearningRates(), Nnet::SetDropoutScale(), Nnet::SetLearningRates(), AmNnet::SetPriors(), kaldi::SplitStringToFloats(), kaldi::SplitStringToVector(), Output::Stream(), Input::Stream(), AmNnet::Write(), and TransitionModel::Write().

27  {
28  try {
29  using namespace kaldi;
30  using namespace kaldi::nnet2;
31  typedef kaldi::int32 int32;
32 
33  const char *usage =
34  "Copy a (nnet2) neural net and its associated transition model,\n"
35  "possibly changing the binary mode\n"
36  "Also supports multiplying all the learning rates by a factor\n"
37  "(the --learning-rate-factor option) and setting them all to a given\n"
38  "value (the --learning-rate options)\n"
39  "\n"
40  "Usage: nnet-am-copy [options] <nnet-in> <nnet-out>\n"
41  "e.g.:\n"
42  " nnet-am-copy --binary=false 1.mdl text.mdl\n";
43 
44  int32 truncate = -1;
45  bool binary_write = true;
46  bool remove_dropout = false;
47  BaseFloat dropout_scale = -1.0;
48  bool remove_preconditioning = false;
49  bool collapse = false;
50  bool match_updatableness = true;
51  BaseFloat learning_rate_factor = 1.0, learning_rate = -1;
52  std::string learning_rate_scales_str = " ";
53  std::string learning_rates = "";
54  std::string scales = "";
55  std::string stats_from;
56 
57  ParseOptions po(usage);
58  po.Register("binary", &binary_write, "Write output in binary mode");
59  po.Register("learning-rate-factor", &learning_rate_factor,
60  "Before copying, multiply all the learning rates in the "
61  "model by this factor.");
62  po.Register("learning-rate", &learning_rate,
63  "If supplied, all the learning rates of \"updatable\" layers"
64  "are set to this value.");
65  po.Register("learning-rates", &learning_rates,
66  "If supplied (a colon-separated list of learning rates), sets "
67  "the learning rates of \"updatable\" layers to these values.");
68  po.Register("scales", &scales,
69  "A colon-separated list of scaling factors, one for each updatable "
70  "layer: a mechanism to scale the parameters.");
71  po.Register("learning-rate-scales", &learning_rate_scales_str,
72  "Colon-separated list of scaling factors for learning rates, "
73  "applied after the --learning-rate and --learning-rates options."
74  "Used to scale learning rates for particular layer types. E.g."
75  "--learning-rate-scales=AffineComponent=0.5");
76  po.Register("truncate", &truncate, "If set, will truncate the neural net "
77  "to this many components by removing the last components.");
78  po.Register("remove-dropout", &remove_dropout, "Set this to true to remove "
79  "any dropout components.");
80  po.Register("dropout-scale", &dropout_scale, "If set, set the dropout scale in any "
81  "dropout components to this value. Note: in traditional dropout, this "
82  "is always zero; you can set it to any value between zero and one.");
83  po.Register("remove-preconditioning", &remove_preconditioning, "Set this to true to replace "
84  "components of type AffineComponentPreconditioned with AffineComponent.");
85  po.Register("stats-from", &stats_from, "Before copying neural net, copy the "
86  "statistics in any layer of type NonlinearComponent, from this "
87  "neural network: provide the extended filename.");
88  po.Register("collapse", &collapse, "If true, collapse sequences of AffineComponents "
89  "and FixedAffineComponents to compactify model");
90  po.Register("match-updatableness", &match_updatableness, "Only relevant if "
91  "collapse=true; set this to false to collapse mixed types.");
92 
93  po.Read(argc, argv);
94 
95  if (po.NumArgs() != 2) {
96  po.PrintUsage();
97  exit(1);
98  }
99 
100  std::string nnet_rxfilename = po.GetArg(1),
101  nnet_wxfilename = po.GetArg(2);
102 
103  TransitionModel trans_model;
104  AmNnet am_nnet;
105  {
106  bool binary;
107  Input ki(nnet_rxfilename, &binary);
108  trans_model.Read(ki.Stream(), binary);
109  am_nnet.Read(ki.Stream(), binary);
110  }
111 
112  if (learning_rate_factor != 1.0)
113  am_nnet.GetNnet().ScaleLearningRates(learning_rate_factor);
114 
115  if (learning_rate >= 0)
116  am_nnet.GetNnet().SetLearningRates(learning_rate);
117 
118  if (learning_rates != "") {
119  std::vector<BaseFloat> learning_rates_vec;
120  if (!SplitStringToFloats(learning_rates, ":", false, &learning_rates_vec)
121  || static_cast<int32>(learning_rates_vec.size()) !=
122  am_nnet.GetNnet().NumUpdatableComponents()) {
123  KALDI_ERR << "Expected --learning-rates option to be a "
124  << "colon-separated string with "
125  << am_nnet.GetNnet().NumUpdatableComponents()
126  << " elements, instead got \"" << learning_rates << '"';
127  }
128  SubVector<BaseFloat> learning_rates_vector(&(learning_rates_vec[0]),
129  learning_rates_vec.size());
130  am_nnet.GetNnet().SetLearningRates(learning_rates_vector);
131  }
132 
133  if (learning_rate_scales_str != " ") {
134  // parse the learning_rate_scales provided as an option
135  std::map<std::string, BaseFloat> learning_rate_scales;
136  std::vector<std::string> learning_rate_scale_vec;
137  SplitStringToVector(learning_rate_scales_str, ":", true,
138  &learning_rate_scale_vec);
139  for (int32 index = 0; index < learning_rate_scale_vec.size();
140  index++) {
141  std::vector<std::string> parts;
142  BaseFloat scale_factor;
143  SplitStringToVector(learning_rate_scale_vec[index],
144  "=", false, &parts);
145  if (!ConvertStringToReal(parts[1], &scale_factor)) {
146  KALDI_ERR << "Unknown format for --learning-rate-scales option. "
147  << "Expected format is "
148  << "--learning-rate-scales=AffineComponent=0.1:AffineComponentPreconditioned=0.5 "
149  << "instead got "
150  << learning_rate_scales_str;
151  }
152  learning_rate_scales.insert(std::pair<std::string, BaseFloat>(
153  parts[0], scale_factor));
154  }
155  // use the learning_rate_scales to scale the component learning rates
156  am_nnet.GetNnet().ScaleLearningRates(learning_rate_scales);
157  }
158 
159  if (scales != "") {
160  std::vector<BaseFloat> scales_vec;
161  if (!SplitStringToFloats(scales, ":", false, &scales_vec)
162  || static_cast<int32>(scales_vec.size()) !=
163  am_nnet.GetNnet().NumUpdatableComponents()) {
164  KALDI_ERR << "Expected --scales option to be a "
165  << "colon-separated string with "
166  << am_nnet.GetNnet().NumUpdatableComponents()
167  << " elements, instead got \"" << scales << '"';
168  }
169  SubVector<BaseFloat> scales_vector(&(scales_vec[0]),
170  scales_vec.size());
171  am_nnet.GetNnet().ScaleComponents(scales_vector);
172  }
173 
174  if (truncate >= 0) {
175  am_nnet.GetNnet().Resize(truncate);
176  if (am_nnet.GetNnet().OutputDim() != am_nnet.Priors().Dim()) {
177  Vector<BaseFloat> empty_priors;
178  am_nnet.SetPriors(empty_priors); // so dims don't disagree.
179  }
180  }
181 
182  if (remove_dropout) am_nnet.GetNnet().RemoveDropout();
183 
184  if (dropout_scale != -1.0) am_nnet.GetNnet().SetDropoutScale(dropout_scale);
185 
186  if (remove_preconditioning) am_nnet.GetNnet().RemovePreconditioning();
187 
188  if (collapse) am_nnet.GetNnet().Collapse(match_updatableness);
189 
190  if (stats_from != "") {
191  // Copy the stats associated with the layers descending from
192  // NonlinearComponent.
193  bool binary;
194  Input ki(stats_from, &binary);
195  TransitionModel trans_model;
196  trans_model.Read(ki.Stream(), binary);
197  AmNnet am_nnet_stats;
198  am_nnet_stats.Read(ki.Stream(), binary);
199  am_nnet.GetNnet().CopyStatsFrom(am_nnet_stats.GetNnet());
200  }
201 
202  {
203  Output ko(nnet_wxfilename, binary_write);
204  trans_model.Write(ko.Stream(), binary_write);
205  am_nnet.Write(ko.Stream(), binary_write);
206  }
207  KALDI_LOG << "Copied neural net from " << nnet_rxfilename
208  << " to " << nnet_wxfilename;
209  return 0;
210  } catch(const std::exception &e) {
211  std::cerr << e.what() << '\n';
212  return -1;
213  }
214 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
bool SplitStringToFloats(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< F > *out)
Definition: text-utils.cc:30
void CopyStatsFrom(const Nnet &nnet)
Copies only the statistics in layers of type NonlinearComponent, from this neural net...
Definition: nnet-nnet.cc:447
int32 NumUpdatableComponents() const
Returns the number of updatable components.
Definition: nnet-nnet.cc:413
void Read(std::istream &is, bool binary)
Definition: am-nnet.cc:39
int32 OutputDim() const
The output dimension of the network – typically the number of pdfs.
Definition: nnet-nnet.cc:31
kaldi::int32 int32
void Resize(int32 num_components)
Removes final components from the neural network (used for debugging).
Definition: nnet-nnet.cc:490
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Read(std::istream &is, bool binary)
const VectorBase< BaseFloat > & Priors() const
Definition: am-nnet.h:67
void SplitStringToVector(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< std::string > *out)
Split a string using any of the single character delimiters.
Definition: text-utils.cc:63
void Write(std::ostream &os, bool binary) const
Definition: am-nnet.cc:31
#define KALDI_ERR
Definition: kaldi-error.h:147
bool ConvertStringToReal(const std::string &str, T *out)
ConvertStringToReal converts a string into either float or double and returns false if there was any ...
Definition: text-utils.cc:238
void ScaleComponents(const VectorBase< BaseFloat > &scales)
Scales the parameters of each of the updatable components.
Definition: nnet-nnet.cc:421
void Write(std::ostream &os, bool binary) const
void RemovePreconditioning()
Replace any components of type AffineComponentPreconditioned with components of type AffineComponent...
Definition: nnet-nnet.cc:531
A class representing a vector.
Definition: kaldi-vector.h:406
void SetLearningRates(BaseFloat learning_rates)
Set all the learning rates in the neural net to this value.
Definition: nnet-nnet.cc:346
void SetDropoutScale(BaseFloat scale)
Calls SetDropoutScale for all the dropout nodes.
Definition: nnet-nnet.cc:516
void RemoveDropout()
Excise any components of type DropoutComponent or AdditiveNoiseComponent.
Definition: nnet-nnet.cc:497
void ScaleLearningRates(BaseFloat factor)
Scale all the learning rates in the neural net by this factor.
Definition: nnet-nnet.cc:313
void Collapse(bool match_updatableness)
Where possible, collapse multiple affine or linear components in a sequence into a single one by comp...
Definition: nnet-nnet.cc:730
void SetPriors(const VectorBase< BaseFloat > &priors)
Definition: am-nnet.cc:44
#define KALDI_LOG
Definition: kaldi-error.h:153
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
Definition: kaldi-vector.h:501
const Nnet & GetNnet() const
Definition: am-nnet.h:61