28 int main(
int argc,
char *argv[]) {
30 using namespace kaldi;
35 "Copy 'raw' nnet3 neural network to standard output\n" 36 "Also supports setting all the learning rates to a value\n" 37 "(the --learning-rate option)\n" 39 "Usage: nnet3-copy [options] <nnet-in> <nnet-out>\n" 41 " nnet3-copy --binary=false 0.raw text.raw\n";
43 bool binary_write =
true;
45 std::string nnet_config, edits_config, edits_str;
47 bool prepare_for_test =
false;
50 po.
Register(
"binary", &binary_write,
"Write output in binary mode");
51 po.
Register(
"learning-rate", &learning_rate,
52 "If supplied, all the learning rates of updatable components" 53 "are set to this value.");
54 po.
Register(
"nnet-config", &nnet_config,
55 "Name of nnet3 config file that can be used to add or replace " 56 "components or nodes of the neural network (the same as you " 57 "would give to nnet3-init).");
58 po.
Register(
"edits-config", &edits_config,
59 "Name of edits-config file that can be used to modify the network " 60 "(applied after nnet-config). See comments for ReadEditConfig()" 61 "in nnet3/nnet-utils.h to see currently supported commands.");
63 "Can be used as an inline alternative to edits-config; semicolons " 64 "will be converted to newlines before parsing. E.g. " 65 "'--edits=remove-orphans'.");
66 po.
Register(
"scale", &scale,
"The parameter matrices are scaled" 67 " by the specified value.");
68 po.
Register(
"prepare-for-test", &prepare_for_test,
69 "If true, prepares the model for test time (may reduce model size " 70 "slightly. Involves setting test mode in dropout and batch-norm " 71 "components, and calling CollapseModel() which may remove some " 80 std::string raw_nnet_rxfilename = po.
GetArg(1),
81 raw_nnet_wxfilename = po.
GetArg(2);
86 if (!nnet_config.empty()) {
87 Input ki(nnet_config);
91 if (learning_rate >= 0)
97 if (!edits_config.empty()) {
98 Input ki(edits_config);
101 if (!edits_str.empty()) {
102 for (
size_t i = 0;
i < edits_str.size();
i++)
103 if (edits_str[
i] ==
';')
105 std::istringstream is(edits_str);
108 if (prepare_for_test) {
114 KALDI_LOG <<
"Copied raw neural net from " << raw_nnet_rxfilename
115 <<
" to " << raw_nnet_wxfilename;
118 }
catch(
const std::exception &e) {
119 std::cerr << e.what() <<
'\n';
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable to be done at test time...
void ScaleNnet(BaseFloat scale, Nnet *nnet)
Scales the nnet parameters and stats by this scale.
void ReadConfig(std::istream &config_file)
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
void ReadEditConfig(std::istream &edit_config_is, Nnet *nnet)
ReadEditConfig() reads a file with a similar-looking format to the config file read by Nnet::ReadConf...
void Register(const std::string &name, bool *ptr, const std::string &doc)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
This file contains some miscellaneous functions dealing with class Nnet.
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
int main(int argc, char *argv[])
void SetLearningRate(BaseFloat learning_rate, Nnet *nnet)
Sets the underlying learning rate for all the components in the nnet to this value.
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Config class for the CollapseModel function.