nnet-copy.cc
Go to the documentation of this file.
1 // nnetbin/nnet-copy.cc
2 
3 // Copyright 2012-2015 Brno University of Technology (author: Karel Vesely)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "nnet/nnet-nnet.h"
#include "nnet/nnet-parallel-component.h"

25 int main(int argc, char *argv[]) {
26  try {
27  using namespace kaldi;
28  using namespace kaldi::nnet1;
29  typedef kaldi::int32 int32;
30 
31  const char *usage =
32  "Copy Neural Network model (and possibly change binary/text format)\n"
33  "Usage: nnet-copy [options] <model-in> <model-out>\n"
34  "e.g.:\n"
35  " nnet-copy --binary=false nnet.mdl nnet_txt.mdl\n";
36 
37  bool binary_write = true;
38  int32 remove_first_components = 0;
39  int32 remove_last_components = 0;
40  BaseFloat dropout_rate = -1.0;
41 
42  ParseOptions po(usage);
43  po.Register("binary", &binary_write, "Write output in binary mode");
44 
45  po.Register("remove-first-layers", &remove_first_components,
46  "Deprecated, please use --remove-first-components");
47  po.Register("remove-last-layers", &remove_last_components,
48  "Deprecated, please use --remove-last-components");
49 
50  po.Register("remove-first-components", &remove_first_components,
51  "Remove N first Components from the Nnet");
52  po.Register("remove-last-components", &remove_last_components,
53  "Remove N last layers Components from the Nnet");
54 
55  po.Register("dropout-rate", &dropout_rate,
56  "Probability that neuron is dropped"
57  "(-1.0 keeps original value).");
58 
59  std::string from_parallel_component;
60  po.Register("from-parallel-component", &from_parallel_component,
61  "Extract nested network from parallel component (two possibilities: "
62  "'3' = search for ParallelComponent and get its 3rd network; "
63  "'1:3' = get 3nd network from 1st component; ID = 1..N).");
64 
65  po.Read(argc, argv);
66 
67  if (po.NumArgs() != 2) {
68  po.PrintUsage();
69  exit(1);
70  }
71 
72  std::string model_in_filename = po.GetArg(1),
73  model_out_filename = po.GetArg(2);
74 
75  // load the network
76  Nnet nnet;
77  {
78  bool binary_read;
79  Input ki(model_in_filename, &binary_read);
80  nnet.Read(ki.Stream(), binary_read);
81  }
82 
83  // eventually replace 'nnet' by nested network from <ParallelComponent>,
84  if (from_parallel_component != "") {
85  std::vector<int32> component_id_nested_id;
86  kaldi::SplitStringToIntegers(from_parallel_component, ":", false,
87  &component_id_nested_id);
88  // parse the argument,
89  int32 component_id = -1, nested_id = 0;
90  switch (component_id_nested_id.size()) {
91  case 1:
92  nested_id = component_id_nested_id[0];
93  break;
94  case 2:
95  component_id = component_id_nested_id[0];
96  nested_id = component_id_nested_id[1];
97  break;
98  default:
99  KALDI_ERR << "Check the csl '--from-parallel-component='"
100  << from_parallel_component
101  << " There must be 1 or 2 elements.";
102  }
103  // search for first <ParallelComponent> (we don't know component_id yet),
104  if (component_id == -1) {
105  for (int32 i = 0; i < nnet.NumComponents(); i++) {
107  component_id = i+1;
108  break;
109  }
110  }
111  }
112  // replace the nnet,
113  KALDI_ASSERT(nnet.GetComponent(component_id-1).GetType() ==
115  ParallelComponent& parallel_comp =
116  dynamic_cast<ParallelComponent&>(nnet.GetComponent(component_id-1));
117  nnet = parallel_comp.GetNestedNnet(nested_id-1); // replace!
118  }
119 
120  // optionally remove N first components,
121  if (remove_first_components > 0) {
122  for (int32 i = 0; i < remove_first_components; i++) {
123  nnet.RemoveComponent(0);
124  }
125  }
126 
127  // optionally remove N last components,
128  if (remove_last_components > 0) {
129  for (int32 i = 0; i < remove_last_components; i++) {
130  nnet.RemoveLastComponent();
131  }
132  }
133 
134  // dropout,
135  if (dropout_rate != -1.0) {
136  nnet.SetDropoutRate(dropout_rate);
137  }
138 
139  // store the network,
140  {
141  Output ko(model_out_filename, binary_write);
142  nnet.Write(ko.Stream(), binary_write);
143  }
144 
145  KALDI_LOG << "Written 'nnet1' to " << model_out_filename;
146  return 0;
147  } catch(const std::exception &e) {
148  std::cerr << e.what();
149  return -1;
150  }
151 }
void RemoveLastComponent()
Remove the last of the Components.
Definition: nnet-nnet.cc:206
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
int32 NumComponents() const
Returns the number of 'Components' which form the NN.
Definition: nnet-nnet.h:66
bool SplitStringToIntegers(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< I > *out)
Split a string (e.g. "1:2:3") into a vector of integers, using the given delimiter(s).
Definition: text-utils.h:68
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename'.
Definition: nnet-nnet.cc:367
kaldi::int32 int32
void Register(const std::string &name, bool *ptr, const std::string &doc)
std::istream & Stream()
Definition: kaldi-io.cc:826
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
std::ostream & Stream()
Definition: kaldi-io.cc:701
int main(int argc, char *argv[])
Definition: nnet-copy.cc:25
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename'.
Definition: nnet-nnet.cc:333
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
int NumArgs() const
Number of positional parameters (c.f. argc-1).
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void SetDropoutRate(BaseFloat r)
Set the dropout rate.
Definition: nnet-nnet.cc:268
void RemoveComponent(int32 c)
Remove the c'th component.
Definition: nnet-nnet.cc:199
const Component & GetComponent(int32 c) const
Component accessor.
Definition: nnet-nnet.cc:153
virtual ComponentType GetType() const =0
Get the type identification of the component.
#define KALDI_LOG
Definition: kaldi-error.h:153