nnet3-copy.cc
Go to the documentation of this file.
1 // nnet3bin/nnet3-copy.cc
2 
3 // Copyright 2012 Johns Hopkins University (author: Daniel Povey)
4 // 2015 Xingyu Na
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include <typeinfo>
22 #include "base/kaldi-common.h"
23 #include "util/common-utils.h"
24 #include "hmm/transition-model.h"
25 #include "nnet3/am-nnet-simple.h"
26 #include "nnet3/nnet-utils.h"
27 
28 int main(int argc, char *argv[]) {
29  try {
30  using namespace kaldi;
31  using namespace kaldi::nnet3;
32  typedef kaldi::int32 int32;
33 
34  const char *usage =
35  "Copy 'raw' nnet3 neural network to standard output\n"
36  "Also supports setting all the learning rates to a value\n"
37  "(the --learning-rate option)\n"
38  "\n"
39  "Usage: nnet3-copy [options] <nnet-in> <nnet-out>\n"
40  "e.g.:\n"
41  " nnet3-copy --binary=false 0.raw text.raw\n";
42 
43  bool binary_write = true;
44  BaseFloat learning_rate = -1;
45  std::string nnet_config, edits_config, edits_str;
46  BaseFloat scale = 1.0;
47  bool prepare_for_test = false;
48 
49  ParseOptions po(usage);
50  po.Register("binary", &binary_write, "Write output in binary mode");
51  po.Register("learning-rate", &learning_rate,
52  "If supplied, all the learning rates of updatable components"
53  "are set to this value.");
54  po.Register("nnet-config", &nnet_config,
55  "Name of nnet3 config file that can be used to add or replace "
56  "components or nodes of the neural network (the same as you "
57  "would give to nnet3-init).");
58  po.Register("edits-config", &edits_config,
59  "Name of edits-config file that can be used to modify the network "
60  "(applied after nnet-config). See comments for ReadEditConfig()"
61  "in nnet3/nnet-utils.h to see currently supported commands.");
62  po.Register("edits", &edits_str,
63  "Can be used as an inline alternative to edits-config; semicolons "
64  "will be converted to newlines before parsing. E.g. "
65  "'--edits=remove-orphans'.");
66  po.Register("scale", &scale, "The parameter matrices are scaled"
67  " by the specified value.");
68  po.Register("prepare-for-test", &prepare_for_test,
69  "If true, prepares the model for test time (may reduce model size "
70  "slightly. Involves setting test mode in dropout and batch-norm "
71  "components, and calling CollapseModel() which may remove some "
72  "components.");
73  po.Read(argc, argv);
74 
75  if (po.NumArgs() != 2) {
76  po.PrintUsage();
77  exit(1);
78  }
79 
80  std::string raw_nnet_rxfilename = po.GetArg(1),
81  raw_nnet_wxfilename = po.GetArg(2);
82 
83  Nnet nnet;
84  ReadKaldiObject(raw_nnet_rxfilename, &nnet);
85 
86  if (!nnet_config.empty()) {
87  Input ki(nnet_config);
88  nnet.ReadConfig(ki.Stream());
89  }
90 
91  if (learning_rate >= 0)
92  SetLearningRate(learning_rate, &nnet);
93 
94  if (scale != 1.0)
95  ScaleNnet(scale, &nnet);
96 
97  if (!edits_config.empty()) {
98  Input ki(edits_config);
99  ReadEditConfig(ki.Stream(), &nnet);
100  }
101  if (!edits_str.empty()) {
102  for (size_t i = 0; i < edits_str.size(); i++)
103  if (edits_str[i] == ';')
104  edits_str[i] = '\n';
105  std::istringstream is(edits_str);
106  ReadEditConfig(is, &nnet);
107  }
108  if (prepare_for_test) {
109  SetBatchnormTestMode(true, &nnet);
110  SetDropoutTestMode(true, &nnet);
112  }
113  WriteKaldiObject(nnet, raw_nnet_wxfilename, binary_write);
114  KALDI_LOG << "Copied raw neural net from " << raw_nnet_rxfilename
115  << " to " << raw_nnet_wxfilename;
116 
117  return 0;
118  } catch(const std::exception &e) {
119  std::cerr << e.what() << '\n';
120  return -1;
121  }
122 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable to be done at test time...
Definition: nnet-utils.cc:2100
void ScaleNnet(BaseFloat scale, Nnet *nnet)
Scales the nnet parameters and stats by this scale.
Definition: nnet-utils.cc:312
void ReadConfig(std::istream &config_file)
Definition: nnet-nnet.cc:189
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
Definition: nnet-utils.cc:564
kaldi::int32 int32
void ReadEditConfig(std::istream &edit_config_is, Nnet *nnet)
ReadEditConfig() reads a file with a similar-looking format to the config file read by Nnet::ReadConf...
Definition: nnet-utils.cc:1234
void Register(const std::string &name, bool *ptr, const std::string &doc)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
This file contains some miscellaneous functions dealing with class Nnet.
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
Definition: nnet-utils.cc:573
std::istream & Stream()
Definition: kaldi-io.cc:826
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
int main(int argc, char *argv[])
Definition: nnet3-copy.cc:28
void SetLearningRate(BaseFloat learning_rate, Nnet *nnet)
Sets the underlying learning rate for all the components in the nnet to this value.
Definition: nnet-utils.cc:276
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Definition: kaldi-io.h:257
#define KALDI_LOG
Definition: kaldi-error.h:153
Config class for the CollapseModel function.
Definition: nnet-utils.h:240