nnet-utils.h
Go to the documentation of this file.
1 // nnet3/nnet-utils.h
2 
3 // Copyright 2015 Johns Hopkins University (author: Daniel Povey)
4 // 2016 Daniel Galvez
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #ifndef KALDI_NNET3_NNET_UTILS_H_
21 #define KALDI_NNET3_NNET_UTILS_H_
22 
23 #include "base/kaldi-common.h"
24 #include "util/kaldi-io.h"
25 #include "matrix/matrix-lib.h"
26 #include "nnet3/nnet-common.h"
28 #include "nnet3/nnet-descriptor.h"
29 #include "nnet3/nnet-computation.h"
30 #include "nnet3/nnet-example.h"
31 
32 namespace kaldi {
33 namespace nnet3 {
34 
35 
38 
45  const Nnet &nnet,
46  const ComputationRequest &request,
47  std::vector<std::vector<bool> > *is_computable);
48 
49 
/// Returns the number of output nodes of this nnet.
51 int32 NumOutputNodes(const Nnet &nnet);
52 
/// Returns the number of input nodes of this nnet.
54 int32 NumInputNodes(const Nnet &nnet);
55 
/// Calls PerturbParams (with the given stddev) on all updatable components
/// of the nnet.
58 void PerturbParams(BaseFloat stddev,
59  Nnet *nnet);
60 
61 
/// Returns the dot product between two networks of the same structure; it
/// calls the DotProduct functions of the networks' components.
65 BaseFloat DotProduct(const Nnet &nnet1,
66  const Nnet &nnet2);
67 
/// Outputs to *dot_prod the dot products between two networks of the same
/// structure (calls the DotProduct functions of their components).
/// NOTE(review): presumably one element per updatable component, i.e. the
/// layout PrintVectorPerUpdatableComponent() prints -- confirm the required
/// dimension of *dot_prod in nnet-utils.cc.
71 void ComponentDotProducts(const Nnet &nnet1,
72  const Nnet &nnet2,
73  VectorBase<BaseFloat> *dot_prod);
74 
/// Prints, to a string which is returned, a vector with one element per
/// updatable component of the nnet.
79 std::string PrintVectorPerUpdatableComponent(const Nnet &nnet,
80  const VectorBase<BaseFloat> &vec);
81 
/// Returns true if the nnet is "simple". One of the required properties is
/// that it has an output called "output"; see the definition in
/// nnet-utils.cc for the full list of properties.
89 bool IsSimpleNnet(const Nnet &nnet);
90 
/// Zeroes the component stats in all nonlinear components in the nnet.
92 void ZeroComponentStats(Nnet *nnet);
93 
94 
/// Computes the left-context and right-context of a nnet, writing them to
/// *left_context and *right_context.
/// NOTE(review): the "Simple" in the name suggests the nnet is expected to
/// satisfy IsSimpleNnet() -- confirm.
103 void ComputeSimpleNnetContext(const Nnet &nnet,
104  int32 *left_context,
105  int32 *right_context);
106 
107 
/// Sets the underlying learning rate for all the components in the nnet to
/// this value.
111 void SetLearningRate(BaseFloat learning_rate,
112  Nnet *nnet);
113 
/// Scales the nnet parameters and stats by this scale.
115 void ScaleNnet(BaseFloat scale, Nnet *nnet);
116 
/// Sets the nnet up as a gradient, by setting is_gradient_ to true and
/// learning_rate_ to 1 for each UpdatableComponent in the nnet.
119 void SetNnetAsGradient(Nnet *nnet);
120 
121 
/// Calls the corresponding function in any component of type
/// StatisticsPoolingComponent in the nnet.
126 void SetRequireDirectInput(bool b, Nnet *nnet);
127 
128 
/// Does *dest += alpha * src (affects nnet parameters and stored stats).
131 void AddNnet(const Nnet &src, BaseFloat alpha, Nnet *dest);
132 
/// Does *dest += alpha * src for updatable components (affects nnet
/// parameters), taking the alpha values from 'alphas', and additionally
/// does *dest += scale * src for other quantities.
/// NOTE(review): presumably 'alphas' has one entry per updatable component
/// and 'scale' applies to the stored stats -- confirm in nnet-utils.cc.
136 void AddNnetComponents(const Nnet &src, const Vector<BaseFloat> &alphas,
137  BaseFloat scale, Nnet *dest);
138 
/// Returns true if 'nnet' has some kind of recurrency.
140 bool NnetIsRecurrent(const Nnet &nnet);
141 
/// Returns the total number of parameters in the updatable components of
/// the nnet.
144 int32 NumParameters(const Nnet &src);
145 
/// Copies the nnet parameters to *params, whose dimension must be equal to
/// NumParameters(src).
148 void VectorizeNnet(const Nnet &src,
149  VectorBase<BaseFloat> *params);
150 
151 
/// Copies the parameters from params to *dest (the counterpart of
/// VectorizeNnet()).
/// NOTE(review): presumably params.Dim() must equal NumParameters(*dest)
/// -- confirm.
154 void UnVectorizeNnet(const VectorBase<BaseFloat> &params,
155  Nnet *dest);
156 
/// Returns the number of updatable components in the nnet.
158 int32 NumUpdatableComponents(const Nnet &dest);
159 
/// Controls whether the natural gradient will be updated for the components
/// of *nnet.
/// NOTE(review): presumably freeze == true stops the updates -- confirm.
161 void FreezeNaturalGradient(bool freeze, Nnet *nnet);
162 
/// Operates in place on *nnet. An overload taking a CompositeComponent* is
/// also defined in nnet-utils.cc.
/// NOTE(review): going by the name, this presumably replaces repeated-affine
/// components with equivalent block-affine ones -- confirm.
165 void ConvertRepeatedToBlockAffine(Nnet *nnet);
166 
/// Returns various info about the neural net, as a string.
172 std::string NnetInfo(const Nnet &nnet);
173 
/// Sets the dropout proportion in all dropout components to
/// dropout_proportion.
176 void SetDropoutProportion(BaseFloat dropout_proportion, Nnet *nnet);
177 
178 
/// Returns true if nnet has at least one component of type
/// BatchNormComponent.
180 bool HasBatchnorm(const Nnet &nnet);
181 
/// Affects only components of type BatchNormComponent.
/// NOTE(review): presumably switches them into (test_mode == true) or out of
/// test mode -- confirm the semantics in nnet-utils.cc.
188 void SetBatchnormTestMode(bool test_mode, Nnet *nnet);
189 
190 
/// Zeros the stored component-level stats in the nnet using
/// ZeroComponentStats(), then recomputes them with the supplied egs.
/// NOTE(review): this description comes from the NnetChainExample overload;
/// confirm that it applies identically to this NnetExample version.
195 void RecomputeStats(const std::vector<NnetExample> &egs, Nnet *nnet);
196 
197 
198 
/// Affects components of child-classes of RandomComponent.
/// NOTE(review): presumably sets or clears their test mode according to
/// test_mode -- confirm in nnet-utils.cc.
205 void SetDropoutTestMode(bool test_mode, Nnet *nnet);
206 
/// Calls ResetGenerator() on all components in 'nnet' that inherit from
/// class RandomComponent.
214 void ResetGenerators(Nnet *nnet);
215 
/// Finds a list of components that are never used, and outputs their
/// integer component indexes to *components.
219 void FindOrphanComponents(const Nnet &nnet, std::vector<int32> *components);
220 
/// Finds a list of nodes that are never used to compute any output, and
/// outputs their integer node indexes to *nodes.
224 void FindOrphanNodes(const Nnet &nnet, std::vector<int32> *nodes);
225 
226 
227 
241  bool collapse_dropout; // dropout then affine/conv.
242  bool collapse_batchnorm; // batchnorm then affine.
243  bool collapse_affine; // affine or fixed-affine then affine.
244  bool collapse_scale; // affine then fixed-scale.
245  CollapseModelConfig(): collapse_dropout(false),
246  collapse_batchnorm(false),
247  collapse_affine(true),
248  collapse_scale(true) { }
249 };
250 
/// Modifies the neural net for efficiency, in a way that is suitable to be
/// done at test time. CollapseModelConfig selects which patterns are
/// collapsed: dropout then affine/conv, batchnorm then affine, affine or
/// fixed-affine then affine, and affine then fixed-scale.
258 void CollapseModel(const CollapseModelConfig &config,
259  Nnet *nnet);
260 
/// Reads, from config_file, a file whose format looks similar to the config
/// files read by the Nnet class.
/// NOTE(review): presumably each line is an edit directive applied to *nnet
/// -- see nnet-utils.cc for the supported directives.
334 void ReadEditConfig(std::istream &config_file, Nnet *nnet);
335 
/// Does the operation '*nnet += scale * delta_nnet', while respecting any
/// max-parameter-change limits (max_param_change, max_change_scale).
/// The counters num_max_change_per_component_applied and
/// num_max_change_global_applied keep track of how many times the max-change
/// was applied. The per-component vector is presumably indexed by updatable
/// component, since MaxChangeStats sizes it with NumUpdatableComponents()
/// -- confirm.
/// NOTE(review): this header does not show what the bool return value means;
/// check nnet-utils.cc.
385 bool UpdateNnetWithMaxChange(const Nnet &delta_nnet,
386  BaseFloat max_param_change,
387  BaseFloat max_change_scale,
388  BaseFloat scale, Nnet *nnet,
389  std::vector<int32> *
390  num_max_change_per_component_applied,
391  int32 *num_max_change_global_applied);
392 
/// Forward declaration; MaxChangeStats is defined later in this header.
393 struct MaxChangeStats;
394 
395 // This overloaded version of UpdateNnetWithMaxChange() is a convenience
396 // wrapper for when you have a MaxChangeStats object to keep track
397 // of how many times the max-change was applied. See the documentation of
// the UpdateNnetWithMaxChange() overload that takes explicit counters.
398 bool UpdateNnetWithMaxChange(const Nnet &delta_nnet,
399  BaseFloat max_param_change,
400  BaseFloat max_change_scale,
401  BaseFloat scale, Nnet *nnet,
402  MaxChangeStats *stats);
403 
404 
/// Used as part of the regular training workflow, prior to
/// UpdateNnetWithMaxChange(). It reads 'nnet' and writes into *delta_nnet.
/// NOTE(review): presumably adds an L2 term, scaled by l2_regularize_scale
/// and derived from 'nnet', into *delta_nnet -- confirm.
464 void ApplyL2Regularization(const Nnet &nnet,
465  BaseFloat l2_regularize_scale,
466  Nnet *delta_nnet);
467 
468 
/// Scales the batchnorm stats of any batchnorm components (components of
/// type BatchNormComponent) in *nnet by batchnorm_stats_scale.
474 void ScaleBatchnormStats(BaseFloat batchnorm_stats_scale,
475  Nnet *nnet);
476 
477 
/// To be called after processing every minibatch; responsible for enforcing
/// the orthogonality constraints in *nnet. See nnet-utils.cc for which
/// components are affected.
501 void ConstrainOrthonormal(Nnet *nnet);
502 
503 
/// Just calls ConsolidateMemory() on all the components of the nnet.
514 void ConsolidateMemory(Nnet *nnet);
515 
/// Returns the number of distinct 'n' values in a training example, given
/// its NnetIo objects.
/// NOTE(review): this header does not show the role of 'exhaustive';
/// presumably it selects a full scan over all indexes rather than a quicker
/// check -- confirm in nnet-utils.cc.
536 int32 GetNumNvalues(const std::vector<NnetIo> &io_vec,
537  bool exhaustive);
538 
539 
544 
545  MaxChangeStats(const Nnet &nnet):
546  num_max_change_global_applied(0),
547  num_minibatches_processed(0),
548  num_max_change_per_component_applied(NumUpdatableComponents(nnet), 0) { }
549 
550  // Prints the max-change stats. Usually will be called at the end
551  // of the program. The nnet is only needed for structural information,
552  // to work out the component names.
553  void Print(const Nnet &nnet) const;
554 };
555 
556 
557 
558 } // namespace nnet3
559 } // namespace kaldi
560 
561 #endif
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable to be done at test time...
Definition: nnet-utils.cc:2100
void ScaleNnet(BaseFloat scale, Nnet *nnet)
Scales the nnet parameters and stats by this scale.
Definition: nnet-utils.cc:312
void EvaluateComputationRequest(const Nnet &nnet, const ComputationRequest &request, std::vector< std::vector< bool > > *is_computable)
Given an nnet and a computation request, this function works out which requested outputs in the compu...
Definition: nnet-utils.cc:71
void SetDropoutProportion(BaseFloat dropout_proportion, Nnet *nnet)
This function sets the dropout proportion in all dropout components to dropout_proportion value...
Definition: nnet-utils.cc:509
void FindOrphanComponents(const Nnet &nnet, std::vector< int32 > *components)
This function finds a list of components that are never used, and outputs the integer component index...
Definition: nnet-utils.cc:591
std::string PrintVectorPerUpdatableComponent(const Nnet &nnet, const VectorBase< BaseFloat > &vec)
This function is for printing, to a string, a vector with one element per updatable component of the ...
Definition: nnet-utils.cc:231
void ComponentDotProducts(const Nnet &nnet1, const Nnet &nnet2, VectorBase< BaseFloat > *dot_prod)
Returns dot products between two networks of the same structure (calls the DotProduct functions of th...
Definition: nnet-utils.cc:211
void FindOrphanNodes(const Nnet &nnet, std::vector< int32 > *nodes)
This function finds a list of nodes that are never used to compute any output, and outputs the intege...
Definition: nnet-utils.cc:607
void ScaleBatchnormStats(BaseFloat batchnorm_stats_scale, Nnet *nnet)
This function scales the batchnorm stats of any batchnorm components (components of type BatchNormComp...
Definition: nnet-utils.cc:536
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
Definition: nnet-utils.cc:564
std::vector< int32 > num_max_change_per_component_applied
Definition: nnet-utils.h:543
kaldi::int32 int32
void ReadEditConfig(std::istream &edit_config_is, Nnet *nnet)
ReadEditConfig() reads a file with a similar-looking format to the config file read by Nnet::ReadConf...
Definition: nnet-utils.cc:1234
void VectorizeNnet(const Nnet &src, VectorBase< BaseFloat > *parameters)
Copies the nnet parameters to *params, whose dimension must be equal to NumParameters(src).
Definition: nnet-utils.cc:378
void ConvertRepeatedToBlockAffine(CompositeComponent *c_component)
Definition: nnet-utils.cc:447
void SetNnetAsGradient(Nnet *nnet)
Sets nnet as gradient by setting is_gradient_ to true and learning_rate_ to 1 for each UpdatableCompo...
Definition: nnet-utils.cc:292
The two main classes defined in this header are struct ComputationRequest, which basically defines a ...
void ConstrainOrthonormal(Nnet *nnet)
This function, to be called after processing every minibatch, is responsible for enforcing the orthog...
Definition: nnet-utils.cc:1108
void UnVectorizeNnet(const VectorBase< BaseFloat > &parameters, Nnet *dest)
Copies the parameters from params to *dest.
Definition: nnet-utils.cc:401
void FreezeNaturalGradient(bool freeze, Nnet *nnet)
Controls if natural gradient will be updated.
Definition: nnet-utils.cc:432
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
Definition: nnet-utils.cc:573
int32 GetNumNvalues(const std::vector< NnetIo > &io_vec, bool exhaustive)
This utility function can be used to obtain the number of distinct 'n' values in a training example...
Definition: nnet-utils.cc:2198
void ResetGenerators(Nnet *nnet)
This function calls 'ResetGenerator()' on all components in 'nnet' that inherit from class RandomComp...
Definition: nnet-utils.cc:582
float BaseFloat
Definition: kaldi-types.h:29
int32 NumParameters(const Nnet &src)
Returns the total of the number of parameters in the updatable components of the nnet.
Definition: nnet-utils.cc:359
void AddNnetComponents(const Nnet &src, const Vector< BaseFloat > &alphas, BaseFloat scale, Nnet *dest)
Does *dest += alpha * src for updatable components (affects nnet parameters), and *dest += scale * sr...
Definition: nnet-utils.cc:322
void ApplyL2Regularization(const Nnet &nnet, BaseFloat l2_regularize_scale, Nnet *delta_nnet)
This function is used as part of the regular training workflow, prior to UpdateNnetWithMaxChange().
Definition: nnet-utils.cc:2244
void ComputeSimpleNnetContext(const Nnet &nnet, int32 *left_context, int32 *right_context)
ComputeSimpleNnetContext computes the left-context and right-context of a nnet.
Definition: nnet-utils.cc:146
void RecomputeStats(const std::vector< NnetChainExample > &egs, const chain::ChainTrainingOptions &chain_config_in, const fst::StdVectorFst &den_fst, Nnet *nnet)
This function zeros the stored component-level stats in the nnet using ZeroComponentStats(), then recomputes them with the supplied egs.
void SetLearningRate(BaseFloat learning_rate, Nnet *nnet)
Sets the underlying learning rate for all the components in the nnet to this value.
Definition: nnet-utils.cc:276
bool HasBatchnorm(const Nnet &nnet)
Returns true if nnet has at least one component of type BatchNormComponent.
Definition: nnet-utils.cc:527
void ZeroComponentStats(Nnet *nnet)
Zeroes the component stats in all nonlinear components in the nnet.
Definition: nnet-utils.cc:269
std::string NnetInfo(const Nnet &nnet)
This function returns various info about the neural net.
Definition: nnet-utils.cc:492
BaseFloat DotProduct(const Nnet &nnet1, const Nnet &nnet2)
Returns dot product between two networks of the same structure (calls the DotProduct functions of the...
Definition: nnet-utils.cc:250
bool IsSimpleNnet(const Nnet &nnet)
This function returns true if the nnet has the following properties: It has an output called "output"...
Definition: nnet-utils.cc:52
void ConsolidateMemory(Nnet *nnet)
This just calls ConsolidateMemory() on all the components of the nnet.
Definition: nnet-utils.cc:1147
void SetRequireDirectInput(bool b, Nnet *nnet)
Calls the corresponding function in any component of type StatisticsPoolingComponent; used as a way t...
Definition: nnet-utils.cc:303
int32 NumInputNodes(const Nnet &nnet)
Returns the number of input nodes of this nnet.
Definition: nnet-utils.cc:43
void Print(const Fst< Arc > &fst, std::string message)
void PerturbParams(BaseFloat stddev, Nnet *nnet)
Calls PerturbParams (with the given stddev) on all updatable components of the nnet.
Definition: nnet-utils.cc:199
int32 NumOutputNodes(const Nnet &nnet)
Returns the number of output nodes of this nnet.
Definition: nnet-utils.cc:35
bool NnetIsRecurrent(const Nnet &nnet)
Returns true if 'nnet' has some kind of recurrency.
Definition: nnet-utils.cc:1441
This file contains class definitions for classes ForwardingDescriptor, SumDescriptor and Descriptor...
bool UpdateNnetWithMaxChange(const Nnet &delta_nnet, BaseFloat max_param_change, BaseFloat max_change_scale, BaseFloat scale, Nnet *nnet, std::vector< int32 > *num_max_change_per_component_applied, int32 *num_max_change_global_applied)
This function does the operation '*nnet += scale * delta_nnet', while respecting any max-parameter-ch...
Definition: nnet-utils.cc:2106
int32 NumUpdatableComponents(const Nnet &dest)
Returns the number of updatable components in the nnet.
Definition: nnet-utils.cc:422
void AddNnet(const Nnet &src, BaseFloat alpha, Nnet *dest)
Does *dest += alpha * src (affects nnet parameters and stored stats).
Definition: nnet-utils.cc:349
MaxChangeStats(const Nnet &nnet)
Definition: nnet-utils.h:545
Config class for the CollapseModel function.
Definition: nnet-utils.h:240