raw-nnet-copy.cc
Go to the documentation of this file.
1 // nnet2bin/raw-nnet-copy.cc
2 
3 // Copyright 2014 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include <typeinfo>
21 #include "base/kaldi-common.h"
22 #include "util/common-utils.h"
23 #include "hmm/transition-model.h"
24 #include "nnet2/am-nnet.h"
25 #include "tree/context-dep.h"
26 
27 int main(int argc, char *argv[]) {
28  try {
29  using namespace kaldi;
30  using namespace kaldi::nnet2;
31  typedef kaldi::int32 int32;
32 
33  const char *usage =
34  "Copy a raw neural net (this version works on raw nnet2 neural nets,\n"
35  "without the transition model. Supports the 'truncate' option.\n"
36  "\n"
37  "Usage: raw-nnet-copy [options] <raw-nnet-in> <raw-nnet-out>\n"
38  "e.g.:\n"
39  " raw-nnet-copy --binary=false 1.mdl text.mdl\n"
40  "See also: nnet-to-raw-nnet, nnet-am-copy\n";
41 
42  int32 truncate = -1;
43  bool binary_write = true;
44  std::string learning_rate_scales_str = " ";
45 
46  ParseOptions po(usage);
47  po.Register("binary", &binary_write, "Write output in binary mode");
48  po.Register("truncate", &truncate, "If set, will truncate the neural net "
49  "to this many components by removing the last components.");
50  po.Register("learning-rate-scales", &learning_rate_scales_str,
51  "Colon-separated list of scaling factors for learning rates, "
52  "applied after the --learning-rate and --learning-rates options."
53  "Used to scale learning rates for particular layer types. E.g."
54  "--learning-rate-scales=AffineComponent=0.5");
55 
56  po.Read(argc, argv);
57 
58  if (po.NumArgs() != 2) {
59  po.PrintUsage();
60  exit(1);
61  }
62 
63  std::string raw_nnet_rxfilename = po.GetArg(1),
64  raw_nnet_wxfilename = po.GetArg(2);
65 
66  Nnet nnet;
67  ReadKaldiObject(raw_nnet_rxfilename, &nnet);
68 
69  if (truncate >= 0)
70  nnet.Resize(truncate);
71 
72  if (learning_rate_scales_str != " ") {
73  // parse the learning_rate_scales provided as an option
74  std::map<std::string, BaseFloat> learning_rate_scales;
75  std::vector<std::string> learning_rate_scale_vec;
76  SplitStringToVector(learning_rate_scales_str, ":", true,
77  &learning_rate_scale_vec);
78  for (int32 index = 0; index < learning_rate_scale_vec.size();
79  index++) {
80  std::vector<std::string> parts;
81  BaseFloat scale_factor;
82  SplitStringToVector(learning_rate_scale_vec[index],
83  "=", false, &parts);
84  if (!ConvertStringToReal(parts[1], &scale_factor)) {
85  KALDI_ERR << "Unknown format for --learning-rate-scales option. "
86  << "Expected format is "
87  << "--learning-rate-scales=AffineComponent=0.1:AffineComponentPreconditioned=0.5 "
88  << "instead got "
89  << learning_rate_scales_str;
90  }
91  learning_rate_scales.insert(std::pair<std::string, BaseFloat>(
92  parts[0], scale_factor));
93  }
94  // use the learning_rate_scales to scale the component learning rates
95  nnet.ScaleLearningRates(learning_rate_scales);
96  }
97 
98  WriteKaldiObject(nnet, raw_nnet_wxfilename, binary_write);
99 
100  KALDI_LOG << "Copied raw neural net from " << raw_nnet_rxfilename
101  << " to " << raw_nnet_wxfilename;
102  return 0;
103  } catch(const std::exception &e) {
104  std::cerr << e.what() << '\n';
105  return -1;
106  }
107 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
int main(int argc, char *argv[])
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
kaldi::int32 int32
void Resize(int32 num_components)
Removes final components from the neural network (used for debugging).
Definition: nnet-nnet.cc:490
void Register(const std::string &name, bool *ptr, const std::string &doc)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void SplitStringToVector(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< std::string > *out)
Split a string using any of the single character delimiters.
Definition: text-utils.cc:63
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
bool ConvertStringToReal(const std::string &str, T *out)
ConvertStringToReal converts a string into either float or double and returns false if there was any ...
Definition: text-utils.cc:238
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Definition: kaldi-io.h:257
void ScaleLearningRates(BaseFloat factor)
Scale all the learning rates in the neural net by this factor.
Definition: nnet-nnet.cc:313
#define KALDI_LOG
Definition: kaldi-error.h:153