nnet3-am-copy.cc File Reference
#include <typeinfo>
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "hmm/transition-model.h"
#include "nnet3/am-nnet-simple.h"
#include "nnet3/nnet-utils.h"
Include dependency graph for nnet3-am-copy.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 28 of file nnet3-am-copy.cc.

References kaldi::nnet3::CollapseModel(), kaldi::nnet3::ConvertRepeatedToBlockAffine(), ParseOptions::GetArg(), AmNnetSimple::GetNnet(), rnnlm::i, KALDI_LOG, ParseOptions::NumArgs(), ParseOptions::PrintUsage(), AmNnetSimple::Read(), ParseOptions::Read(), TransitionModel::Read(), Nnet::ReadConfig(), kaldi::nnet3::ReadEditConfig(), kaldi::ReadKaldiObject(), ParseOptions::Register(), kaldi::nnet3::ScaleNnet(), kaldi::nnet3::SetBatchnormTestMode(), AmNnetSimple::SetContext(), kaldi::nnet3::SetDropoutTestMode(), kaldi::nnet3::SetLearningRate(), AmNnetSimple::SetNnet(), Output::Stream(), Input::Stream(), AmNnetSimple::Write(), TransitionModel::Write(), and kaldi::WriteKaldiObject().

28  {
29  try {
30  using namespace kaldi;
31  using namespace kaldi::nnet3;
32  typedef kaldi::int32 int32;
33 
34  const char *usage =
35  "Copy nnet3 neural-net acoustic model file; supports conversion\n"
36  "to raw model (--raw=true).\n"
37  "Also supports setting all learning rates to a supplied\n"
38  "value (the --learning-rate option),\n"
39  "and supports replacing the raw nnet in the model (the Nnet)\n"
40  "with a provided raw nnet (the --set-raw-nnet option)\n"
41  "\n"
42  "Usage: nnet3-am-copy [options] <nnet-in> <nnet-out>\n"
43  "e.g.:\n"
44  " nnet3-am-copy --binary=false 1.mdl text.mdl\n"
45  " nnet3-am-copy --raw=true 1.mdl 1.raw\n";
46 
47  bool binary_write = true,
48  raw = false;
49  BaseFloat learning_rate = -1;
50  std::string set_raw_nnet = "";
51  bool convert_repeated_to_block = false;
52  BaseFloat scale = 1.0;
53  bool prepare_for_test = false;
54  std::string nnet_config, edits_config, edits_str;
55 
56  ParseOptions po(usage);
57  po.Register("binary", &binary_write, "Write output in binary mode");
58  po.Register("raw", &raw, "If true, write only 'raw' neural net "
59  "without transition model and priors.");
60  po.Register("set-raw-nnet", &set_raw_nnet,
61  "Set the raw nnet inside the model to the one provided in "
62  "the option string (interpreted as an rxfilename). Done "
63  "before the learning-rate is changed.");
64  po.Register("convert-repeated-to-block", &convert_repeated_to_block,
65  "Convert all RepeatedAffineComponents and "
66  "NaturalGradientRepeatedAffineComponents to "
67  "BlockAffineComponents in the model. Done after set-raw-nnet.");
68  po.Register("nnet-config", &nnet_config,
69  "Name of nnet3 config file that can be used to add or replace "
70  "components or nodes of the neural network (the same as you "
71  "would give to nnet3-init).");
72  po.Register("edits-config", &edits_config,
73  "Name of edits-config file that can be used to modify the network "
74  "(applied after nnet-config). See comments for ReadEditConfig()"
75  "in nnet3/nnet-utils.h to see currently supported commands.");
76  po.Register("edits", &edits_str,
77  "Can be used as an inline alternative to --edits-config; "
78  "semicolons will be converted to newlines before parsing. E.g. "
79  "'--edits=remove-orphans'.");
80  po.Register("learning-rate", &learning_rate,
81  "If supplied, all the learning rates of updatable components"
82  " are set to this value.");
83  po.Register("scale", &scale, "The parameter matrices are scaled"
84  " by the specified value.");
85  po.Register("prepare-for-test", &prepare_for_test,
86  "If true, prepares the model for test time (may reduce model size "
87  "slightly. Involves setting test mode in dropout and batch-norm "
88  "components, and calling CollapseModel() which may remove some "
89  "components.");
90 
91  po.Read(argc, argv);
92 
93  if (po.NumArgs() != 2) {
94  po.PrintUsage();
95  exit(1);
96  }
97 
98  std::string nnet_rxfilename = po.GetArg(1),
99  nnet_wxfilename = po.GetArg(2);
100 
101  TransitionModel trans_model;
102  AmNnetSimple am_nnet;
103  {
104  bool binary;
105  Input ki(nnet_rxfilename, &binary);
106  trans_model.Read(ki.Stream(), binary);
107  am_nnet.Read(ki.Stream(), binary);
108  }
109 
110  if (!set_raw_nnet.empty()) {
111  Nnet nnet;
112  ReadKaldiObject(set_raw_nnet, &nnet);
113  am_nnet.SetNnet(nnet);
114  }
115 
116  if (!nnet_config.empty()) {
117  Input ki(nnet_config);
118  am_nnet.GetNnet().ReadConfig(ki.Stream());
119  }
120 
121  if(convert_repeated_to_block)
123 
124  if (learning_rate >= 0)
125  SetLearningRate(learning_rate, &(am_nnet.GetNnet()));
126 
127  if (!edits_config.empty()) {
128  Input ki(edits_config);
129  ReadEditConfig(ki.Stream(), &(am_nnet.GetNnet()));
130  }
131  if (!edits_str.empty()) {
132  for (size_t i = 0; i < edits_str.size(); i++)
133  if (edits_str[i] == ';')
134  edits_str[i] = '\n';
135  std::istringstream is(edits_str);
136  ReadEditConfig(is, &(am_nnet.GetNnet()));
137  }
138 
139  am_nnet.SetContext(); // in case we used the config or edits-config or
140  // edits options
141 
142  if (scale != 1.0)
143  ScaleNnet(scale, &(am_nnet.GetNnet()));
144 
145  if (prepare_for_test) {
146  SetBatchnormTestMode(true, &am_nnet.GetNnet());
147  SetDropoutTestMode(true, &am_nnet.GetNnet());
149  }
150 
151  if (raw) {
152  WriteKaldiObject(am_nnet.GetNnet(), nnet_wxfilename, binary_write);
153  KALDI_LOG << "Copied neural net from " << nnet_rxfilename
154  << " to raw format as " << nnet_wxfilename;
155 
156  } else {
157  Output ko(nnet_wxfilename, binary_write);
158  trans_model.Write(ko.Stream(), binary_write);
159  am_nnet.Write(ko.Stream(), binary_write);
160  KALDI_LOG << "Copied neural net from " << nnet_rxfilename
161  << " to " << nnet_wxfilename;
162  }
163  return 0;
164  } catch(const std::exception &e) {
165  std::cerr << e.what() << '\n';
166  return -1;
167  }
168 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable to be done at test time...
Definition: nnet-utils.cc:2100
void ScaleNnet(BaseFloat scale, Nnet *nnet)
Scales the nnet parameters and stats by this scale.
Definition: nnet-utils.cc:312
void ReadConfig(std::istream &config_file)
Definition: nnet-nnet.cc:189
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
Definition: nnet-utils.cc:564
kaldi::int32 int32
const Nnet & GetNnet() const
void ReadEditConfig(std::istream &edit_config_is, Nnet *nnet)
ReadEditConfig() reads a file with a similar-looking format to the config file read by Nnet::ReadConf...
Definition: nnet-utils.cc:1234
void Read(std::istream &is, bool binary)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
void ConvertRepeatedToBlockAffine(CompositeComponent *c_component)
Definition: nnet-utils.cc:447
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
Definition: nnet-utils.cc:573
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Read(std::istream &is, bool binary)
void SetLearningRate(BaseFloat learning_rate, Nnet *nnet)
Sets the underlying learning rate for all the components in the nnet to this value.
Definition: nnet-utils.cc:276
void SetNnet(const Nnet &nnet)
void SetContext()
This function works out the left_context_ and right_context_ variables from the network (it's a rathe...
void Write(std::ostream &os, bool binary) const
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Definition: kaldi-io.h:257
#define KALDI_LOG
Definition: kaldi-error.h:153
void Write(std::ostream &os, bool binary) const
Config class for the CollapseModel function.
Definition: nnet-utils.h:240