29 using namespace kaldi;
34 "Insert components into a neural network-based acoustic model.\n" 35 "This is mostly intended for adding new hidden layers to neural networks.\n" 36 "You can either specify the option --insert-at=n (specifying the index of\n" 37 "the component after which you want your neural network inserted), or by\n" 38 "default this program will insert it just before the component before the\n" 39 "softmax component. CAUTION: It will also randomize the parameters of the\n" 40 "component before the softmax (typically AffineComponent), with stddev equal\n" 41 "to the --stddev-factor option (default 0.1), times the inverse square root\n" 42 "of the number of inputs to that component.\n" 43 "Set --randomize-next-component=false to turn this off.\n" 45 "Usage: nnet-insert [options] <nnet-in> <raw-nnet-to-insert-in> <nnet-out>\n" 47 " nnet-insert 1.nnet \"nnet-init hidden_layer.config -|\" 2.nnet\n";
49 bool binary_write =
true;
50 bool randomize_next_component =
true;
57 po.Register(
"binary", &binary_write,
"Write output in binary mode");
58 po.Register(
"randomize-next-component", &randomize_next_component,
59 "If true, randomize the parameters of the next component after " 60 "what we insert (which must be updatable).");
61 po.Register(
"insert-at", &insert_at,
"Inserts new components before the " 62 "specified component (note: indexes are zero-based). If <0, " 63 "inserts before the component before the softmax.");
64 po.Register(
"stddev-factor", &stddev_factor,
"Factor on the standard " 65 "deviation when randomizing next component (only relevant if " 66 "--randomize-next-component=true");
67 po.Register(
"srand", &srand_seed,
"Seed for random number generator");
72 if (po.NumArgs() != 3) {
77 std::string nnet_rxfilename = po.GetArg(1),
78 raw_nnet_rxfilename = po.GetArg(2),
79 nnet_wxfilename = po.GetArg(3);
85 Input ki(nnet_rxfilename, &binary);
86 trans_model.
Read(ki.Stream(), binary);
87 am_nnet.
Read(ki.Stream(), binary);
93 if (insert_at == -1) {
95 KALDI_ERR <<
"We don't know where to insert the new components: " 96 "the neural net doesn't have exactly one softmax component, " 97 "and you didn't use the --insert-at option.";
107 <<
"position " << insert_at;
109 if (randomize_next_component) {
114 KALDI_ERR <<
"You have --randomize-next-component=true, but the " 115 <<
"component to randomize is not updatable: " 116 << component->
Info();
117 bool treat_as_gradient =
false;
118 uc->
SetZero(treat_as_gradient);
120 std::sqrt(static_cast<BaseFloat>(uc->
InputDim()));
122 KALDI_LOG <<
"Randomized component index " << c <<
" with stddev " 128 Output ko(nnet_wxfilename, binary_write);
129 trans_model.
Write(ko.Stream(), binary_write);
130 am_nnet.
Write(ko.Stream(), binary_write);
132 KALDI_LOG <<
"Write neural-net acoustic model to " << nnet_wxfilename;
134 }
catch(
const std::exception &e) {
135 std::cerr << e.what() <<
'\n';
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
const Component & GetComponent(int32 c) const
virtual int32 InputDim() const =0
Get size of input vectors.
int32 IndexOfSoftmaxLayer(const Nnet &nnet)
If "nnet" has exactly one softmax layer, this function will return its index; otherwise it will retur...
Abstract class, basic element of the network; it is a box with defined inputs, outputs, and transformation functions interface.
void Read(std::istream &is, bool binary)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
int32 NumComponents() const
Returns number of components -- think of this as similar to # of layers, but e.g.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
virtual void SetZero(bool treat_as_gradient)=0
Set parameters to zero, and if treat_as_gradient is true, we'll be treating this as a gradient so set...
void Read(std::istream &is, bool binary)
void Write(std::ostream &os, bool binary) const
void InsertComponents(const Nnet &src_nnet, int32 c_to_insert, Nnet *dest_nnet)
Inserts the components of one neural network into a particular place in the other one...
virtual std::string Info() const
virtual void PerturbParams(BaseFloat stddev)=0
We introduce a new virtual function that only applies to class UpdatableComponent.
void Write(std::ostream &os, bool binary) const
const Nnet & GetNnet() const
Class UpdatableComponent is a Component which has trainable parameters and contains some global param...