nnet-insert.cc File Reference
Include dependency graph for nnet-insert.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 27 of file nnet-insert.cc.

References ParseOptions::GetArg(), Nnet::GetComponent(), AmNnet::GetNnet(), kaldi::nnet2::IndexOfSoftmaxLayer(), Component::Info(), Component::InputDim(), kaldi::nnet2::InsertComponents(), KALDI_ERR, KALDI_LOG, ParseOptions::NumArgs(), Nnet::NumComponents(), UpdatableComponent::PerturbParams(), ParseOptions::PrintUsage(), AmNnet::Read(), ParseOptions::Read(), TransitionModel::Read(), kaldi::ReadKaldiObject(), ParseOptions::Register(), UpdatableComponent::SetZero(), Output::Stream(), Input::Stream(), AmNnet::Write(), and TransitionModel::Write().

27  {
28  try {
29  using namespace kaldi;
30  using namespace kaldi::nnet2;
31  typedef kaldi::int32 int32;
32 
33  const char *usage =
34  "Insert components into a neural network-based acoustic model.\n"
35  "This is mostly intended for adding new hidden layers to neural networks.\n"
36  "You can either specify the option --insert-at=n (specifying the index of\n"
37  "the component after which you want your neural network inserted), or by\n"
38  "default this program will insert it just before the component before the\n"
39  "softmax component. CAUTION: It will also randomize the parameters of the\n"
40  "component before the softmax (typically AffineComponent), with stddev equal\n"
41  "to the --stddev-factor option (default 0.1), times the inverse square root\n"
42  "of the number of inputs to that component.\n"
43  "Set --randomize-next-component=false to turn this off.\n"
44  "\n"
45  "Usage: nnet-insert [options] <nnet-in> <raw-nnet-to-insert-in> <nnet-out>\n"
46  "e.g.:\n"
47  " nnet-insert 1.nnet \"nnet-init hidden_layer.config -|\" 2.nnet\n";
48 
49  bool binary_write = true;
50  bool randomize_next_component = true;
51  int32 insert_at = -1;
52  BaseFloat stddev_factor = 0.1;
53  int32 srand_seed = 0;
54 
55  ParseOptions po(usage);
56 
57  po.Register("binary", &binary_write, "Write output in binary mode");
58  po.Register("randomize-next-component", &randomize_next_component,
59  "If true, randomize the parameters of the next component after "
60  "what we insert (which must be updatable).");
61  po.Register("insert-at", &insert_at, "Inserts new components before the "
62  "specified component (note: indexes are zero-based). If <0, "
63  "inserts before the component before the softmax.");
64  po.Register("stddev-factor", &stddev_factor, "Factor on the standard "
65  "deviation when randomizing next component (only relevant if "
66  "--randomize-next-component=true");
67  po.Register("srand", &srand_seed, "Seed for random number generator");
68 
69  po.Read(argc, argv);
70  srand(srand_seed);
71 
72  if (po.NumArgs() != 3) {
73  po.PrintUsage();
74  exit(1);
75  }
76 
77  std::string nnet_rxfilename = po.GetArg(1),
78  raw_nnet_rxfilename = po.GetArg(2),
79  nnet_wxfilename = po.GetArg(3);
80 
81  TransitionModel trans_model;
82  AmNnet am_nnet;
83  {
84  bool binary;
85  Input ki(nnet_rxfilename, &binary);
86  trans_model.Read(ki.Stream(), binary);
87  am_nnet.Read(ki.Stream(), binary);
88  }
89 
90  Nnet src_nnet; // the one we'll insert.
91  ReadKaldiObject(raw_nnet_rxfilename, &src_nnet);
92 
93  if (insert_at == -1) {
94  if ((insert_at = IndexOfSoftmaxLayer(am_nnet.GetNnet())) == -1)
95  KALDI_ERR << "We don't know where to insert the new components: "
96  "the neural net doesn't have exactly one softmax component, "
97  "and you didn't use the --insert-at option.";
98  insert_at--; // we want to insert before the linearity before
99  // the softmax layer.
100  }
101 
102  // This function is declared in nnet-functions.h
103  InsertComponents(src_nnet,
104  insert_at,
105  &(am_nnet.GetNnet()));
106  KALDI_LOG << "Inserted " << src_nnet.NumComponents() << " components at "
107  << "position " << insert_at;
108 
109  if (randomize_next_component) {
110  int32 c = insert_at + src_nnet.NumComponents();
111  kaldi::nnet2::Component *component = &(am_nnet.GetNnet().GetComponent(c));
112  UpdatableComponent *uc = dynamic_cast<UpdatableComponent*>(component);
113  if (!uc)
114  KALDI_ERR << "You have --randomize-next-component=true, but the "
115  << "component to randomize is not updatable: "
116  << component->Info();
117  bool treat_as_gradient = false;
118  uc->SetZero(treat_as_gradient);
119  BaseFloat stddev = stddev_factor /
120  std::sqrt(static_cast<BaseFloat>(uc->InputDim()));
121  uc->PerturbParams(stddev);
122  KALDI_LOG << "Randomized component index " << c << " with stddev "
123  << stddev;
124  }
125 
126 
127  {
128  Output ko(nnet_wxfilename, binary_write);
129  trans_model.Write(ko.Stream(), binary_write);
130  am_nnet.Write(ko.Stream(), binary_write);
131  }
132  KALDI_LOG << "Write neural-net acoustic model to " << nnet_wxfilename;
133  return 0;
134  } catch(const std::exception &e) {
135  std::cerr << e.what() << '\n';
136  return -1;
137  }
138 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
Definition: chain.dox:20
const Component & GetComponent(int32 c) const
Definition: nnet-nnet.cc:141
virtual int32 InputDim() const =0
Get size of input vectors.
int32 IndexOfSoftmaxLayer(const Nnet &nnet)
If "nnet" has exactly one softmax layer, this function will return its index; otherwise it will retur...
Abstract class, basic element of the network: a box with defined inputs and outputs and a transformation-function interface.
void Read(std::istream &is, bool binary)
Definition: am-nnet.cc:39
kaldi::int32 int32
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
int32 NumComponents() const
Returns the number of components – think of this as similar to the number of layers, but e.g.
Definition: nnet-nnet.h:69
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
virtual void SetZero(bool treat_as_gradient)=0
Set parameters to zero, and if treat_as_gradient is true, we'll be treating this as a gradient so set...
void Read(std::istream &is, bool binary)
void Write(std::ostream &os, bool binary) const
Definition: am-nnet.cc:31
#define KALDI_ERR
Definition: kaldi-error.h:147
void InsertComponents(const Nnet &src_nnet, int32 c_to_insert, Nnet *dest_nnet)
Inserts the components of one neural network into a particular place in the other one...
virtual std::string Info() const
virtual void PerturbParams(BaseFloat stddev)=0
We introduce a new virtual function that only applies to class UpdatableComponent.
void Write(std::ostream &os, bool binary) const
#define KALDI_LOG
Definition: kaldi-error.h:153
const Nnet & GetNnet() const
Definition: am-nnet.h:61
Class UpdatableComponent is a Component which has trainable parameters and contains some global param...