// nnet-copy: command-line tool that reads an nnet1 model, optionally edits
// it, and opens <model-out> for writing with a selectable binary/text format.
//
// Visible flow: register options; read positional args <model-in> and
// <model-out>; open the input; optionally replace the network with a nested
// network taken from a parallel component; optionally drop N leading and/or
// N trailing components; optionally change the dropout rate; open the output
// and log where the model was written.
//
// NOTE(review): this listing appears to be a scrape of a documentation
// source browser. The leading integers on many lines (25, 27, 32, ...) are
// original file line numbers, not code. Several original lines are absent:
// the declarations of `po`, `dropout_rate`, `binary_read`, `nnet` and
// `parallel_comp`, the opening `try`, the loop/branch bodies, closing braces
// and return statements. As shown it does not compile; restore it from the
// upstream source before building.
25 int main(
int argc,
char *argv[]) {
27 using namespace kaldi;
// Usage/help text (the variable it is assigned to is not visible here).
32 "Copy Neural Network model (and possibly change binary/text format)\n" 33 "Usage: nnet-copy [options] <model-in> <model-out>\n" 35 " nnet-copy --binary=false nnet.mdl nnet_txt.mdl\n";
// Option defaults: write binary output, remove nothing.
37 bool binary_write =
true;
38 int32 remove_first_components = 0;
39 int32 remove_last_components = 0;
// --binary selects binary (true) or text (false) output.
43 po.
Register(
"binary", &binary_write,
"Write output in binary mode");
// Deprecated aliases: these flags bind the same variables as
// --remove-first-components / --remove-last-components below, so either
// spelling has the same effect.
45 po.
Register(
"remove-first-layers", &remove_first_components,
46 "Deprecated, please use --remove-first-components");
47 po.
Register(
"remove-last-layers", &remove_last_components,
48 "Deprecated, please use --remove-last-components");
50 po.
Register(
"remove-first-components", &remove_first_components,
51 "Remove N first Components from the Nnet");
// NOTE(review): help text reads "Remove N last layers Components"; the word
// "layers" looks like a leftover from the deprecated option name.
52 po.
Register(
"remove-last-components", &remove_last_components,
53 "Remove N last layers Components from the Nnet");
// --dropout-rate: -1.0 means "keep the value stored in the model".
// NOTE(review): the two adjacent literals concatenate to
// "...is dropped(-1.0 keeps original value)." with no separating space.
55 po.
Register(
"dropout-rate", &dropout_rate,
56 "Probability that neuron is dropped" 57 "(-1.0 keeps original value).");
// --from-parallel-component: either "N" (find a ParallelComponent and take
// its N-th nested network) or "C:N" (take nested network N of component C);
// both IDs are 1-based per the help text. An empty string disables this step.
59 std::string from_parallel_component;
// NOTE(review): help text says "get 3nd network"; this is likely meant to be "3rd".
60 po.
Register(
"from-parallel-component", &from_parallel_component,
61 "Extract nested network from parallel component (two possibilities: " 62 "'3' = search for ParallelComponent and get its 3rd network; " 63 "'1:3' = get 3nd network from 1st component; ID = 1..N).");
// Positional arguments (GetArg is 1-based): 1 = <model-in>, 2 = <model-out>.
72 std::string model_in_filename = po.
GetArg(1),
73 model_out_filename = po.
GetArg(2);
// Open the input model; Input reports through binary_read whether the stream
// is binary. The declaration of binary_read and the model Read call are not
// visible here.
79 Input ki(model_in_filename, &binary_read);
// Optional nested-network extraction, only when the option was given.
84 if (from_parallel_component !=
"") {
// Holds the parsed integer IDs. The splitting call is not visible;
// presumably SplitStringToIntegers on ':'. TODO confirm.
85 std::vector<int32> component_id_nested_id;
87 &component_id_nested_id);
// component_id stays -1 ("not specified") for the 1-element "N" form; the
// 2-element "C:N" form sets both IDs.
89 int32 component_id = -1, nested_id = 0;
90 switch (component_id_nested_id.size()) {
92 nested_id = component_id_nested_id[0];
95 component_id = component_id_nested_id[0];
96 nested_id = component_id_nested_id[1];
// Any other element count is rejected with an error (presumably the
// `default` branch; the case labels are not visible).
99 KALDI_ERR <<
"Check the csl '--from-parallel-component='" 100 << from_parallel_component
101 <<
" There must be 1 or 2 elements.";
// No component given: per the help text, search for a ParallelComponent
// (the search body is not visible here).
104 if (component_id == -1) {
// Convert the 1-based nested ID to 0-based and replace the whole network
// with the selected nested network.
117 nnet = parallel_comp.GetNestedNnet(nested_id-1);
// Optionally drop the first N components; presumably one leading component
// is removed per iteration (loop body not visible).
121 if (remove_first_components > 0) {
122 for (int32
i = 0;
i < remove_first_components;
i++) {
// Optionally drop the last N components; presumably one trailing component
// is removed per iteration (loop body not visible).
128 if (remove_last_components > 0) {
129 for (int32
i = 0;
i < remove_last_components;
i++) {
// -1.0 is the "keep original value" sentinel from --dropout-rate, so the
// rate is only changed when the user supplied one. The body is not visible;
// presumably it calls SetDropoutRate.
135 if (dropout_rate != -1.0) {
// Open the output; binary_write (--binary) selects binary vs. text format.
// The Write call itself is not visible here.
141 Output ko(model_out_filename, binary_write);
145 KALDI_LOG <<
"Written 'nnet1' to " << model_out_filename;
147 }
// Any std::exception is reported on stderr. The function's return value on
// this path is not visible here.
catch(
const std::exception &e) {
148 std::cerr << e.what();
void RemoveLastComponent()
Remove the last of the Components.
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
int32 NumComponents() const
Returns the number of 'Components' which form the NN.
bool SplitStringToIntegers(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< I > *out)
Split a string (e.g. "1:3") into a vector of integers.
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename'.
void Register(const std::string &name, bool *ptr, const std::string &doc)
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
int main(int argc, char *argv[])
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename'.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
int NumArgs() const
Number of positional parameters (cf. argc-1).
#define KALDI_ASSERT(cond)
void SetDropoutRate(BaseFloat r)
Set the dropout rate.
void RemoveComponent(int32 c)
Remove the c'th component.
const Component & GetComponent(int32 c) const
Component accessor.
virtual ComponentType GetType() const =0
Get the type identification of the component.