raw-nnet-copy.cc File Reference
#include <typeinfo>
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "hmm/transition-model.h"
#include "nnet2/am-nnet.h"
#include "tree/context-dep.h"
Include dependency graph for raw-nnet-copy.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 27 of file raw-nnet-copy.cc.

References kaldi::ConvertStringToReal(), ParseOptions::GetArg(), KALDI_ERR, KALDI_LOG, ParseOptions::NumArgs(), ParseOptions::PrintUsage(), ParseOptions::Read(), kaldi::ReadKaldiObject(), ParseOptions::Register(), Nnet::Resize(), Nnet::ScaleLearningRates(), kaldi::SplitStringToVector(), and kaldi::WriteKaldiObject().

27  {
28  try {
29  using namespace kaldi;
30  using namespace kaldi::nnet2;
31  typedef kaldi::int32 int32;
32 
33  const char *usage =
34  "Copy a raw neural net (this version works on raw nnet2 neural nets,\n"
35  "without the transition model. Supports the 'truncate' option.\n"
36  "\n"
37  "Usage: raw-nnet-copy [options] <raw-nnet-in> <raw-nnet-out>\n"
38  "e.g.:\n"
39  " raw-nnet-copy --binary=false 1.mdl text.mdl\n"
40  "See also: nnet-to-raw-nnet, nnet-am-copy\n";
41 
42  int32 truncate = -1;
43  bool binary_write = true;
44  std::string learning_rate_scales_str = " ";
45 
46  ParseOptions po(usage);
47  po.Register("binary", &binary_write, "Write output in binary mode");
48  po.Register("truncate", &truncate, "If set, will truncate the neural net "
49  "to this many components by removing the last components.");
50  po.Register("learning-rate-scales", &learning_rate_scales_str,
51  "Colon-separated list of scaling factors for learning rates, "
52  "applied after the --learning-rate and --learning-rates options."
53  "Used to scale learning rates for particular layer types. E.g."
54  "--learning-rate-scales=AffineComponent=0.5");
55 
56  po.Read(argc, argv);
57 
58  if (po.NumArgs() != 2) {
59  po.PrintUsage();
60  exit(1);
61  }
62 
63  std::string raw_nnet_rxfilename = po.GetArg(1),
64  raw_nnet_wxfilename = po.GetArg(2);
65 
66  Nnet nnet;
67  ReadKaldiObject(raw_nnet_rxfilename, &nnet);
68 
69  if (truncate >= 0)
70  nnet.Resize(truncate);
71 
72  if (learning_rate_scales_str != " ") {
73  // parse the learning_rate_scales provided as an option
74  std::map<std::string, BaseFloat> learning_rate_scales;
75  std::vector<std::string> learning_rate_scale_vec;
76  SplitStringToVector(learning_rate_scales_str, ":", true,
77  &learning_rate_scale_vec);
78  for (int32 index = 0; index < learning_rate_scale_vec.size();
79  index++) {
80  std::vector<std::string> parts;
81  BaseFloat scale_factor;
82  SplitStringToVector(learning_rate_scale_vec[index],
83  "=", false, &parts);
84  if (!ConvertStringToReal(parts[1], &scale_factor)) {
85  KALDI_ERR << "Unknown format for --learning-rate-scales option. "
86  << "Expected format is "
87  << "--learning-rate-scales=AffineComponent=0.1:AffineComponentPreconditioned=0.5 "
88  << "instead got "
89  << learning_rate_scales_str;
90  }
91  learning_rate_scales.insert(std::pair<std::string, BaseFloat>(
92  parts[0], scale_factor));
93  }
94  // use the learning_rate_scales to scale the component learning rates
95  nnet.ScaleLearningRates(learning_rate_scales);
96  }
97 
98  WriteKaldiObject(nnet, raw_nnet_wxfilename, binary_write);
99 
100  KALDI_LOG << "Copied raw neural net from " << raw_nnet_rxfilename
101  << " to " << raw_nnet_wxfilename;
102  return 0;
103  } catch(const std::exception &e) {
104  std::cerr << e.what() << '\n';
105  return -1;
106  }
107 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
Definition: chain.dox:20
kaldi::int32 int32
void Resize(int32 num_components)
Removes final components from the neural network (used for debugging).
Definition: nnet-nnet.cc:490
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void SplitStringToVector(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< std::string > *out)
Split a string using any of the single character delimiters.
Definition: text-utils.cc:63
#define KALDI_ERR
Definition: kaldi-error.h:147
bool ConvertStringToReal(const std::string &str, T *out)
ConvertStringToReal converts a string into either float or double and returns false if there was any ...
Definition: text-utils.cc:238
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)
Definition: kaldi-io.h:257
void ScaleLearningRates(BaseFloat factor)
Scale all the learning rates in the neural net by this factor.
Definition: nnet-nnet.cc:313
#define KALDI_LOG
Definition: kaldi-error.h:153