nnet3-copy.cc File Reference
#include <typeinfo>
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "hmm/transition-model.h"
#include "nnet3/am-nnet-simple.h"
#include "nnet3/nnet-utils.h"
Include dependency graph for nnet3-copy.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main(int argc, char *argv[])

Definition at line 28 of file nnet3-copy.cc.

References kaldi::nnet3::CollapseModel(), ParseOptions::GetArg(), rnnlm::i, KALDI_LOG, ParseOptions::NumArgs(), ParseOptions::PrintUsage(), ParseOptions::Read(), Nnet::ReadConfig(), kaldi::nnet3::ReadEditConfig(), kaldi::ReadKaldiObject(), ParseOptions::Register(), kaldi::nnet3::ScaleNnet(), kaldi::nnet3::SetBatchnormTestMode(), kaldi::nnet3::SetDropoutTestMode(), kaldi::nnet3::SetLearningRate(), Input::Stream(), and kaldi::WriteKaldiObject().

28  {
29  try {
30  using namespace kaldi;
31  using namespace kaldi::nnet3;
32  typedef kaldi::int32 int32;
33 
34  const char *usage =
35  "Copy 'raw' nnet3 neural network to standard output\n"
36  "Also supports setting all the learning rates to a value\n"
37  "(the --learning-rate option)\n"
38  "\n"
39  "Usage: nnet3-copy [options] <nnet-in> <nnet-out>\n"
40  "e.g.:\n"
41  " nnet3-copy --binary=false 0.raw text.raw\n";
42 
43  bool binary_write = true;
44  BaseFloat learning_rate = -1;
45  std::string nnet_config, edits_config, edits_str;
46  BaseFloat scale = 1.0;
47  bool prepare_for_test = false;
48 
49  ParseOptions po(usage);
50  po.Register("binary", &binary_write, "Write output in binary mode");
51  po.Register("learning-rate", &learning_rate,
52  "If supplied, all the learning rates of updatable components"
53  "are set to this value.");
54  po.Register("nnet-config", &nnet_config,
55  "Name of nnet3 config file that can be used to add or replace "
56  "components or nodes of the neural network (the same as you "
57  "would give to nnet3-init).");
58  po.Register("edits-config", &edits_config,
59  "Name of edits-config file that can be used to modify the network "
60  "(applied after nnet-config). See comments for ReadEditConfig()"
61  "in nnet3/nnet-utils.h to see currently supported commands.");
62  po.Register("edits", &edits_str,
63  "Can be used as an inline alternative to edits-config; semicolons "
64  "will be converted to newlines before parsing. E.g. "
65  "'--edits=remove-orphans'.");
66  po.Register("scale", &scale, "The parameter matrices are scaled"
67  " by the specified value.");
68  po.Register("prepare-for-test", &prepare_for_test,
69  "If true, prepares the model for test time (may reduce model size "
70  "slightly. Involves setting test mode in dropout and batch-norm "
71  "components, and calling CollapseModel() which may remove some "
72  "components.");
73  po.Read(argc, argv);
74 
75  if (po.NumArgs() != 2) {
76  po.PrintUsage();
77  exit(1);
78  }
79 
80  std::string raw_nnet_rxfilename = po.GetArg(1),
81  raw_nnet_wxfilename = po.GetArg(2);
82 
83  Nnet nnet;
84  ReadKaldiObject(raw_nnet_rxfilename, &nnet);
85 
86  if (!nnet_config.empty()) {
87  Input ki(nnet_config);
88  nnet.ReadConfig(ki.Stream());
89  }
90 
91  if (learning_rate >= 0)
92  SetLearningRate(learning_rate, &nnet);
93 
94  if (scale != 1.0)
95  ScaleNnet(scale, &nnet);
96 
97  if (!edits_config.empty()) {
98  Input ki(edits_config);
99  ReadEditConfig(ki.Stream(), &nnet);
100  }
101  if (!edits_str.empty()) {
102  for (size_t i = 0; i < edits_str.size(); i++)
103  if (edits_str[i] == ';')
104  edits_str[i] = '\n';
105  std::istringstream is(edits_str);
106  ReadEditConfig(is, &nnet);
107  }
108  if (prepare_for_test) {
109  SetBatchnormTestMode(true, &nnet);
110  SetDropoutTestMode(true, &nnet);
112  }
113  WriteKaldiObject(nnet, raw_nnet_wxfilename, binary_write);
114  KALDI_LOG << "Copied raw neural net from " << raw_nnet_rxfilename
115  << " to " << raw_nnet_wxfilename;
116 
117  return 0;
118  } catch(const std::exception &e) {
119  std::cerr << e.what() << '\n';
120  return -1;
121  }
122 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable to be done at test time...
Definition: nnet-utils.cc:2100
void ScaleNnet(BaseFloat scale, Nnet *nnet)
Scales the nnet parameters and stats by this scale.
Definition: nnet-utils.cc:312
void ReadConfig(std::istream &config_file)
Definition: nnet-nnet.cc:189
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
Definition: nnet-utils.cc:564
kaldi::int32 int32
void ReadEditConfig(std::istream &edit_config_is, Nnet *nnet)
ReadEditConfig() reads a file with a similar-looking format to the config file read by Nnet::ReadConf...
Definition: nnet-utils.cc:1234
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
Definition: nnet-utils.cc:573
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void SetLearningRate(BaseFloat learning_rate, Nnet *nnet)
Sets the underlying learning rate for all the components in the nnet to this value.
Definition: nnet-utils.cc:276
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Definition: kaldi-io.h:257
#define KALDI_LOG
Definition: kaldi-error.h:153
Config class for the CollapseModel function.
Definition: nnet-utils.h:240