nnet-utils.h File Reference

This file contains some miscellaneous functions dealing with class Nnet. More...


Classes

struct  CollapseModelConfig
 Config class for the CollapseModel function. More...
 
struct  MaxChangeStats
 

Namespaces

 kaldi
 This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for mispronunciation detection tasks, the reference:
 
 kaldi::nnet3
 

Functions

void EvaluateComputationRequest (const Nnet &nnet, const ComputationRequest &request, std::vector< std::vector< bool > > *is_computable)
 Given an nnet and a computation request, this function works out which requested outputs in the computation request are computable; it outputs this information as a vector "is_computable" indexed by the same indexes as request.outputs. More...
 
int32 NumOutputNodes (const Nnet &nnet)
 Returns the number of output nodes of this nnet. More...
 
int32 NumInputNodes (const Nnet &nnet)
 Returns the number of input nodes of this nnet. More...
 
void PerturbParams (BaseFloat stddev, Nnet *nnet)
 Calls PerturbParams (with the given stddev) on all updatable components of the nnet. More...
 
BaseFloat DotProduct (const Nnet &nnet1, const Nnet &nnet2)
 Returns dot product between two networks of the same structure (calls the DotProduct functions of the Updatable components and sums up the return values). More...
 
void ComponentDotProducts (const Nnet &nnet1, const Nnet &nnet2, VectorBase< BaseFloat > *dot_prod)
 Computes dot products between two networks of the same structure (calls the DotProduct functions of the Updatable components and fills in the output vector). More...
 
std::string PrintVectorPerUpdatableComponent (const Nnet &nnet, const VectorBase< BaseFloat > &vec)
 This function is for printing, to a string, a vector with one element per updatable component of the nnet (e.g. More...
 
bool IsSimpleNnet (const Nnet &nnet)
 This function returns true if the nnet has the following properties: It has an output called "output" (other outputs are allowed but may be ignored). More...
 
void ZeroComponentStats (Nnet *nnet)
 Zeroes the component stats in all nonlinear components in the nnet. More...
 
void ComputeSimpleNnetContext (const Nnet &nnet, int32 *left_context, int32 *right_context)
 ComputeSimpleNnetContext computes the left-context and right-context of an nnet. More...
 
void SetLearningRate (BaseFloat learning_rate, Nnet *nnet)
 Sets the underlying learning rate for all the components in the nnet to this value. More...
 
void ScaleNnet (BaseFloat scale, Nnet *nnet)
 Scales the nnet parameters and stats by this scale. More...
 
void SetNnetAsGradient (Nnet *nnet)
 Sets nnet as gradient by setting is_gradient_ to true and learning_rate_ to 1 for each UpdatableComponent in nnet. More...
 
void SetRequireDirectInput (bool b, Nnet *nnet)
 Calls the corresponding function in any component of type StatisticsPoolingComponent; used as a way to compute the 'real' left-right context of networks including StatisticsPoolingComponent, which will give you the minimum chunk size they can consume. More...
 
void AddNnet (const Nnet &src, BaseFloat alpha, Nnet *dest)
 Does *dest += alpha * src (affects nnet parameters and stored stats). More...
 
void AddNnetComponents (const Nnet &src, const Vector< BaseFloat > &alphas, BaseFloat scale, Nnet *dest)
 Does *dest += alpha * src for updatable components (affects nnet parameters), and *dest += scale * src for other components (affects stored stats). More...
 
bool NnetIsRecurrent (const Nnet &nnet)
 Returns true if 'nnet' has some kind of recurrency. More...
 
int32 NumParameters (const Nnet &src)
 Returns the total of the number of parameters in the updatable components of the nnet. More...
 
void VectorizeNnet (const Nnet &src, VectorBase< BaseFloat > *params)
 Copies the nnet parameters to *params, whose dimension must be equal to NumParameters(src). More...
 
void UnVectorizeNnet (const VectorBase< BaseFloat > &params, Nnet *dest)
 Copies the parameters from params to *dest. More...
 
int32 NumUpdatableComponents (const Nnet &dest)
 Returns the number of updatable components in the nnet. More...
 
void FreezeNaturalGradient (bool freeze, Nnet *nnet)
 Controls if natural gradient will be updated. More...
 
void ConvertRepeatedToBlockAffine (Nnet *nnet)
 Convert all components of type RepeatedAffineComponent or NaturalGradientRepeatedAffineComponent to BlockAffineComponent in nnet. More...
 
std::string NnetInfo (const Nnet &nnet)
 This function returns various info about the neural net. More...
 
void SetDropoutProportion (BaseFloat dropout_proportion, Nnet *nnet)
 This function sets the dropout proportion in all dropout components to dropout_proportion value. More...
 
bool HasBatchnorm (const Nnet &nnet)
 Returns true if nnet has at least one component of type BatchNormComponent. More...
 
void SetBatchnormTestMode (bool test_mode, Nnet *nnet)
 This function affects only components of type BatchNormComponent. More...
 
void RecomputeStats (const std::vector< NnetExample > &egs, Nnet *nnet)
 This function zeros the stored component-level stats in the nnet using ZeroComponentStats(), then recomputes them with the supplied egs. More...
 
void SetDropoutTestMode (bool test_mode, Nnet *nnet)
 This function affects components of child-classes of RandomComponent. More...
 
void ResetGenerators (Nnet *nnet)
 This function calls 'ResetGenerator()' on all components in 'nnet' that inherit from class RandomComponent. More...
 
void FindOrphanComponents (const Nnet &nnet, std::vector< int32 > *components)
 This function finds a list of components that are never used, and outputs the integer component indexes (you can use these to index nnet.GetComponentNames() to get their names). More...
 
void FindOrphanNodes (const Nnet &nnet, std::vector< int32 > *nodes)
 This function finds a list of nodes that are never used to compute any output, and outputs the integer node indexes (you can use these to index nnet.GetNodeNames() to get their names). More...
 
void CollapseModel (const CollapseModelConfig &config, Nnet *nnet)
 This function modifies the neural net for efficiency, in a way that is suitable to be done at test time. More...
 
void ReadEditConfig (std::istream &config_file, Nnet *nnet)
 ReadEditConfig() reads a file with a similar-looking format to the config file read by Nnet::ReadConfig(), but this consists of a sequence of operations to perform on an existing network, mostly modifying components. More...
 
bool UpdateNnetWithMaxChange (const Nnet &delta_nnet, BaseFloat max_param_change, BaseFloat max_change_scale, BaseFloat scale, Nnet *nnet, std::vector< int32 > *num_max_change_per_component_applied, int32 *num_max_change_global_applied)
 This function does the operation '*nnet += scale * delta_nnet', while respecting any max-parameter-change (max-param-change) specified in the updatable components, and also the global max-param-change specified as 'max_param_change'. More...
 
bool UpdateNnetWithMaxChange (const Nnet &delta_nnet, BaseFloat max_param_change, BaseFloat max_change_scale, BaseFloat scale, Nnet *nnet, MaxChangeStats *stats)
 
void ApplyL2Regularization (const Nnet &nnet, BaseFloat l2_regularize_scale, Nnet *delta_nnet)
 This function is used as part of the regular training workflow, prior to UpdateNnetWithMaxChange(). More...
 
void ScaleBatchnormStats (BaseFloat batchnorm_stats_scale, Nnet *nnet)
 This function scales the batchnorm stats of any batchnorm components (components of type BatchNormComponent) in 'nnet' by the scale 'batchnorm_stats_scale'. More...
 
void ConstrainOrthonormal (Nnet *nnet)
 This function, to be called after processing every minibatch, is responsible for enforcing the orthogonality constraint for any components of type LinearComponent or inheriting from AffineComponent that have the "orthonormal_constraint" value set. More...
 
void ConsolidateMemory (Nnet *nnet)
 This just calls ConsolidateMemory() on all the components of the nnet. More...
 
int32 GetNumNvalues (const std::vector< NnetIo > &io_vec, bool exhaustive)
 This utility function can be used to obtain the number of distinct 'n' values in a training example. More...
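
As a rough illustration of the arithmetic described for AddNnet(), ComponentDotProducts() and DotProduct() in the list above, the sketch below uses a hypothetical stand-in type, ToyNnet, that holds one flat parameter vector per updatable component. ToyNnet and all Toy* functions are assumptions made for this example and are not part of Kaldi; the real functions operate on kaldi::nnet3::Nnet, and AddNnet() also affects stored stats, which the sketch ignores.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a network: one flat parameter vector per
// updatable component. This is NOT the kaldi::nnet3::Nnet class.
using ToyNnet = std::vector<std::vector<float>>;

// Mirrors the documented behaviour of AddNnet: *dest += alpha * src.
void ToyAddNnet(const ToyNnet &src, float alpha, ToyNnet *dest) {
  assert(src.size() == dest->size());
  for (std::size_t c = 0; c < src.size(); ++c) {
    assert(src[c].size() == (*dest)[c].size());
    for (std::size_t i = 0; i < src[c].size(); ++i)
      (*dest)[c][i] += alpha * src[c][i];
  }
}

// Mirrors ComponentDotProducts: one dot product per component.
std::vector<float> ToyComponentDotProducts(const ToyNnet &a,
                                           const ToyNnet &b) {
  assert(a.size() == b.size());
  std::vector<float> dot_prod(a.size(), 0.0f);
  for (std::size_t c = 0; c < a.size(); ++c) {
    assert(a[c].size() == b[c].size());
    for (std::size_t i = 0; i < a[c].size(); ++i)
      dot_prod[c] += a[c][i] * b[c][i];
  }
  return dot_prod;
}

// Mirrors DotProduct: the sum of the per-component dot products.
float ToyDotProduct(const ToyNnet &a, const ToyNnet &b) {
  float sum = 0.0f;
  for (float d : ToyComponentDotProducts(a, b)) sum += d;
  return sum;
}
```

Both networks must have the same structure, which the sketch checks with assert; the descriptions above state the same requirement for the real functions.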
 

Detailed Description

This file contains some miscellaneous functions dealing with class Nnet.
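
The relationship between NumParameters(), VectorizeNnet() and UnVectorizeNnet() can be pictured as a flatten-and-restore round trip. The sketch below is a minimal illustration under stated assumptions: ToyNnet (one float vector per updatable component) and the Toy* functions are hypothetical stand-ins rather than Kaldi code, and the component-by-component layout of the flat vector is assumed for the example, not taken from the Kaldi source.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a network: one flat parameter vector per
// updatable component. This is NOT the kaldi::nnet3::Nnet class.
using ToyNnet = std::vector<std::vector<float>>;

// Total number of parameters over all components.
std::size_t ToyNumParameters(const ToyNnet &nnet) {
  std::size_t total = 0;
  for (const std::vector<float> &component : nnet) total += component.size();
  return total;
}

// Copies all parameters into *params, whose size must already equal
// ToyNumParameters(src). Assumed layout: components in order.
void ToyVectorizeNnet(const ToyNnet &src, std::vector<float> *params) {
  assert(params->size() == ToyNumParameters(src));
  std::size_t offset = 0;
  for (const std::vector<float> &component : src)
    for (float value : component)
      (*params)[offset++] = value;
}

// Copies parameters from params back into *dest, using the same layout.
void ToyUnVectorizeNnet(const std::vector<float> &params, ToyNnet *dest) {
  assert(params.size() == ToyNumParameters(*dest));
  std::size_t offset = 0;
  for (std::vector<float> &component : *dest)
    for (float &value : component)
      value = params[offset++];
}
```

As with the documented VectorizeNnet(), the caller sizes the output vector beforehand; the sketch only checks the dimension and does not resize it.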

Definition in file nnet-utils.h.
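
The idea of a global max-param-change, as mentioned for UpdateNnetWithMaxChange(), can be sketched as follows. This is a deliberately simplified illustration and not the Kaldi implementation: it assumes the limit applies to the 2-norm of the scaled update, treats the whole network as one flat parameter vector, and leaves out the per-component limits, max_change_scale and the statistics outputs. ToyUpdateWithGlobalMaxChange is a hypothetical name introduced for this example.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Does *params += f * delta, where f == scale unless the 2-norm of
// scale * delta exceeds max_param_change, in which case f is reduced so
// that the norm of the applied change equals max_param_change.
// Returns the factor f that was actually applied.
// Assumption: max_param_change <= 0 means "no limit".
float ToyUpdateWithGlobalMaxChange(const std::vector<float> &delta,
                                   float max_param_change, float scale,
                                   std::vector<float> *params) {
  assert(delta.size() == params->size());
  double sumsq = 0.0;
  for (float d : delta) sumsq += static_cast<double>(d) * d;
  const double change_norm = std::fabs(scale) * std::sqrt(sumsq);

  float applied = scale;
  if (max_param_change > 0.0f && change_norm > max_param_change)
    applied = static_cast<float>(scale * (max_param_change / change_norm));

  for (std::size_t i = 0; i < delta.size(); ++i)
    (*params)[i] += applied * delta[i];
  return applied;
}
```

For instance, with delta = (3, 4), scale = 1 and max_param_change = 2.5, the proposed change has norm 5, so the sketch applies half of it.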