All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
nnet-copy.cc File Reference
Include dependency graph for nnet-copy.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

int main ( int  argc,
char *  argv[] 
)

Definition at line 25 of file nnet-copy.cc.

References ParseOptions::GetArg(), Nnet::GetComponent(), Component::GetType(), rnnlm::i, KALDI_ASSERT, KALDI_ERR, KALDI_LOG, Component::kParallelComponent, ParseOptions::NumArgs(), Nnet::NumComponents(), ParseOptions::PrintUsage(), ParseOptions::Read(), Nnet::Read(), ParseOptions::Register(), Nnet::RemoveComponent(), Nnet::RemoveLastComponent(), Nnet::SetDropoutRate(), kaldi::SplitStringToIntegers(), Output::Stream(), Input::Stream(), and Nnet::Write().

25  {
26  try {
27  using namespace kaldi;
28  using namespace kaldi::nnet1;
29  typedef kaldi::int32 int32;
30 
31  const char *usage =
32  "Copy Neural Network model (and possibly change binary/text format)\n"
33  "Usage: nnet-copy [options] <model-in> <model-out>\n"
34  "e.g.:\n"
35  " nnet-copy --binary=false nnet.mdl nnet_txt.mdl\n";
36 
37  bool binary_write = true;
38  int32 remove_first_components = 0;
39  int32 remove_last_components = 0;
40  BaseFloat dropout_rate = -1.0;
41 
42  ParseOptions po(usage);
43  po.Register("binary", &binary_write, "Write output in binary mode");
44 
45  po.Register("remove-first-layers", &remove_first_components,
46  "Deprecated, please use --remove-first-components");
47  po.Register("remove-last-layers", &remove_last_components,
48  "Deprecated, please use --remove-last-components");
49 
50  po.Register("remove-first-components", &remove_first_components,
51  "Remove N first Components from the Nnet");
52  po.Register("remove-last-components", &remove_last_components,
53  "Remove N last layers Components from the Nnet");
54 
55  po.Register("dropout-rate", &dropout_rate,
56  "Probability that neuron is dropped"
57  "(-1.0 keeps original value).");
58 
59  std::string from_parallel_component;
60  po.Register("from-parallel-component", &from_parallel_component,
61  "Extract nested network from parallel component (two possibilities: "
62  "'3' = search for ParallelComponent and get its 3rd network; "
63  "'1:3' = get 3nd network from 1st component; ID = 1..N).");
64 
65  po.Read(argc, argv);
66 
67  if (po.NumArgs() != 2) {
68  po.PrintUsage();
69  exit(1);
70  }
71 
72  std::string model_in_filename = po.GetArg(1),
73  model_out_filename = po.GetArg(2);
74 
75  // load the network
76  Nnet nnet;
77  {
78  bool binary_read;
79  Input ki(model_in_filename, &binary_read);
80  nnet.Read(ki.Stream(), binary_read);
81  }
82 
83  // eventually replace 'nnet' by nested network from <ParallelComponent>,
84  if (from_parallel_component != "") {
85  std::vector<int32> component_id_nested_id;
86  kaldi::SplitStringToIntegers(from_parallel_component, ":", false,
87  &component_id_nested_id);
88  // parse the argument,
89  int32 component_id = -1, nested_id = 0;
90  switch (component_id_nested_id.size()) {
91  case 1:
92  nested_id = component_id_nested_id[0];
93  break;
94  case 2:
95  component_id = component_id_nested_id[0];
96  nested_id = component_id_nested_id[1];
97  break;
98  default:
99  KALDI_ERR << "Check the csl '--from-parallel-component='"
100  << from_parallel_component
101  << " There must be 1 or 2 elements.";
102  }
103  // search for first <ParallelComponent> (we don't know component_id yet),
104  if (component_id == -1) {
105  for (int32 i = 0; i < nnet.NumComponents(); i++) {
107  component_id = i+1;
108  break;
109  }
110  }
111  }
112  // replace the nnet,
113  KALDI_ASSERT(nnet.GetComponent(component_id-1).GetType() ==
115  ParallelComponent& parallel_comp =
116  dynamic_cast<ParallelComponent&>(nnet.GetComponent(component_id-1));
117  nnet = parallel_comp.GetNestedNnet(nested_id-1); // replace!
118  }
119 
120  // optionally remove N first components,
121  if (remove_first_components > 0) {
122  for (int32 i = 0; i < remove_first_components; i++) {
123  nnet.RemoveComponent(0);
124  }
125  }
126 
127  // optionally remove N last components,
128  if (remove_last_components > 0) {
129  for (int32 i = 0; i < remove_last_components; i++) {
130  nnet.RemoveLastComponent();
131  }
132  }
133 
134  // dropout,
135  if (dropout_rate != -1.0) {
136  nnet.SetDropoutRate(dropout_rate);
137  }
138 
139  // store the network,
140  {
141  Output ko(model_out_filename, binary_write);
142  nnet.Write(ko.Stream(), binary_write);
143  }
144 
145  KALDI_LOG << "Written 'nnet1' to " << model_out_filename;
146  return 0;
147  } catch(const std::exception &e) {
148  std::cerr << e.what();
149  return -1;
150  }
151 }
void RemoveLastComponent()
Remove the last of the Components.
Definition: nnet-nnet.cc:206
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
bool SplitStringToIntegers(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< I > *out)
Split a string (e.g. "1:2:3") into a vector of integers.
Definition: text-utils.h:64
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
#define KALDI_ERR
Definition: kaldi-error.h:127
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename'.
Definition: nnet-nnet.cc:333
const Component & GetComponent(int32 c) const
Component accessor.
Definition: nnet-nnet.cc:153
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:169
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename'.
Definition: nnet-nnet.cc:367
void SetDropoutRate(BaseFloat r)
Set the dropout rate.
Definition: nnet-nnet.cc:268
void RemoveComponent(int32 c)
Remove the c'th component.
Definition: nnet-nnet.cc:199
virtual ComponentType GetType() const =0
Get the type identification of the component.
int32 NumComponents() const
Returns the number of 'Components' which form the NN.
Definition: nnet-nnet.h:66
#define KALDI_LOG
Definition: kaldi-error.h:133