28 int main(
int argc,
char *argv[]) {
30 using namespace kaldi;
35 "Copy nnet3 neural-net acoustic model file; supports conversion\n" 36 "to raw model (--raw=true).\n" 37 "Also supports setting all learning rates to a supplied\n" 38 "value (the --learning-rate option),\n" 39 "and supports replacing the raw nnet in the model (the Nnet)\n" 40 "with a provided raw nnet (the --set-raw-nnet option)\n" 42 "Usage: nnet3-am-copy [options] <nnet-in> <nnet-out>\n" 44 " nnet3-am-copy --binary=false 1.mdl text.mdl\n" 45 " nnet3-am-copy --raw=true 1.mdl 1.raw\n";
47 bool binary_write =
true,
50 std::string set_raw_nnet =
"";
51 bool convert_repeated_to_block =
false;
53 bool prepare_for_test =
false;
54 std::string nnet_config, edits_config, edits_str;
57 po.
Register(
"binary", &binary_write,
"Write output in binary mode");
58 po.
Register(
"raw", &raw,
"If true, write only 'raw' neural net " 59 "without transition model and priors.");
60 po.
Register(
"set-raw-nnet", &set_raw_nnet,
61 "Set the raw nnet inside the model to the one provided in " 62 "the option string (interpreted as an rxfilename). Done " 63 "before the learning-rate is changed.");
64 po.
Register(
"convert-repeated-to-block", &convert_repeated_to_block,
65 "Convert all RepeatedAffineComponents and " 66 "NaturalGradientRepeatedAffineComponents to " 67 "BlockAffineComponents in the model. Done after set-raw-nnet.");
68 po.
Register(
"nnet-config", &nnet_config,
69 "Name of nnet3 config file that can be used to add or replace " 70 "components or nodes of the neural network (the same as you " 71 "would give to nnet3-init).");
72 po.
Register(
"edits-config", &edits_config,
73 "Name of edits-config file that can be used to modify the network " 74 "(applied after nnet-config). See comments for ReadEditConfig()" 75 "in nnet3/nnet-utils.h to see currently supported commands.");
77 "Can be used as an inline alternative to --edits-config; " 78 "semicolons will be converted to newlines before parsing. E.g. " 79 "'--edits=remove-orphans'.");
80 po.
Register(
"learning-rate", &learning_rate,
81 "If supplied, all the learning rates of updatable components" 82 " are set to this value.");
83 po.
Register(
"scale", &scale,
"The parameter matrices are scaled" 84 " by the specified value.");
85 po.
Register(
"prepare-for-test", &prepare_for_test,
86 "If true, prepares the model for test time (may reduce model size " 87 "slightly. Involves setting test mode in dropout and batch-norm " 88 "components, and calling CollapseModel() which may remove some " 98 std::string nnet_rxfilename = po.
GetArg(1),
99 nnet_wxfilename = po.
GetArg(2);
105 Input ki(nnet_rxfilename, &binary);
110 if (!set_raw_nnet.empty()) {
116 if (!nnet_config.empty()) {
117 Input ki(nnet_config);
121 if(convert_repeated_to_block)
124 if (learning_rate >= 0)
127 if (!edits_config.empty()) {
128 Input ki(edits_config);
131 if (!edits_str.empty()) {
132 for (
size_t i = 0;
i < edits_str.size();
i++)
133 if (edits_str[
i] ==
';')
135 std::istringstream is(edits_str);
145 if (prepare_for_test) {
153 KALDI_LOG <<
"Copied neural net from " << nnet_rxfilename
154 <<
" to raw format as " << nnet_wxfilename;
157 Output ko(nnet_wxfilename, binary_write);
160 KALDI_LOG <<
"Copied neural net from " << nnet_rxfilename
161 <<
" to " << nnet_wxfilename;
164 }
catch(
const std::exception &e) {
165 std::cerr << e.what() <<
'\n';
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable to be done at test time...
void ScaleNnet(BaseFloat scale, Nnet *nnet)
Scales the nnet parameters and stats by this scale.
void ReadConfig(std::istream &config_file)
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
const Nnet & GetNnet() const
void ReadEditConfig(std::istream &edit_config_is, Nnet *nnet)
ReadEditConfig() reads a file with a similar-looking format to the config file read by Nnet::ReadConf...
void Read(std::istream &is, bool binary)
void Register(const std::string &name, bool *ptr, const std::string &doc)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
This file contains some miscellaneous functions dealing with class Nnet.
void ConvertRepeatedToBlockAffine(CompositeComponent *c_component)
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
void Read(std::istream &is, bool binary)
void SetLearningRate(BaseFloat learning_rate, Nnet *nnet)
Sets the underlying learning rate for all the components in the nnet to this value.
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
void SetNnet(const Nnet &nnet)
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void SetContext()
This function works out the left_context_ and right_context_ variables from the network (it's a rathe...
void Write(std::ostream &os, bool binary) const
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
void Write(std::ostream &os, bool binary) const
Config class for the CollapseModel function.
int main(int argc, char *argv[])