This file contains some miscellaneous functions dealing with class Nnet.
#include "base/kaldi-common.h"
#include "util/kaldi-io.h"
#include "matrix/matrix-lib.h"
#include "nnet3/nnet-common.h"
#include "nnet3/nnet-component-itf.h"
#include "nnet3/nnet-descriptor.h"
#include "nnet3/nnet-computation.h"
#include "nnet3/nnet-example.h"
Classes
struct CollapseModelConfig
Config class for the CollapseModel function.
struct MaxChangeStats
Namespaces
kaldi
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for mispronunciation detection tasks; the reference:
kaldi::nnet3
Functions
void EvaluateComputationRequest (const Nnet &nnet, const ComputationRequest &request, std::vector< std::vector< bool > > *is_computable)
Given an nnet and a computation request, this function works out which requested outputs in the computation request are computable; it outputs this information as a vector "is_computable" indexed by the same indexes as request.outputs.
int32 NumOutputNodes (const Nnet &nnet)
Returns the number of output nodes of this nnet.
int32 NumInputNodes (const Nnet &nnet)
Returns the number of input nodes of this nnet.
void PerturbParams (BaseFloat stddev, Nnet *nnet)
Calls PerturbParams (with the given stddev) on all updatable components of the nnet.
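The entry above does not say what perturbing a component involves. As an illustration only, the sketch below assumes it means adding zero-mean Gaussian noise with standard deviation stddev to each parameter of an updatable component; ToyComponent and ToyPerturbParams are hypothetical names, not Kaldi's API.

```cpp
#include <cassert>
#include <random>
#include <vector>

// Hypothetical toy component: a flat parameter vector plus a flag saying
// whether the component is updatable. This is not Kaldi's Component class.
struct ToyComponent {
  bool updatable;
  std::vector<float> params;
};

// Adds zero-mean Gaussian noise with standard deviation 'stddev' to every
// parameter of every updatable component (an assumed meaning of "perturb").
void ToyPerturbParams(float stddev, std::vector<ToyComponent> *nnet,
                      std::mt19937 *rng) {
  assert(stddev >= 0.0f);
  if (stddev == 0.0f) return;  // normal_distribution requires stddev > 0
  std::normal_distribution<float> noise(0.0f, stddev);
  for (ToyComponent &comp : *nnet) {
    if (!comp.updatable) continue;  // non-updatable components are skipped
    for (float &p : comp.params) p += noise(*rng);
  }
}
```

Passing the random generator in explicitly keeps the perturbation reproducible for a fixed seed.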
BaseFloat DotProduct (const Nnet &nnet1, const Nnet &nnet2)
Returns the dot product between two networks of the same structure (calls the DotProduct functions of the updatable components and sums up the return values).
void ComponentDotProducts (const Nnet &nnet1, const Nnet &nnet2, VectorBase< BaseFloat > *dot_prod)
Returns the dot products between two networks of the same structure (calls the DotProduct functions of the updatable components and fills in the output vector).
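The relationship between DotProduct() and ComponentDotProducts() can be illustrated with a minimal sketch on a hypothetical toy network (ToyComponent, ToyNnet, ToyComponentDotProducts and ToyDotProduct are invented here and are not Kaldi's types): the network-level value is the sum of the per-updatable-component values.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical, simplified stand-in for an nnet: each component is a flat
// parameter vector plus a flag saying whether it is updatable.
struct ToyComponent {
  bool updatable;
  std::vector<float> params;
};

using ToyNnet = std::vector<ToyComponent>;

// One dot product per updatable component, in component order
// (the role ComponentDotProducts() plays for a real Nnet).
std::vector<float> ToyComponentDotProducts(const ToyNnet &a, const ToyNnet &b) {
  assert(a.size() == b.size());  // networks must have the same structure
  std::vector<float> dots;
  for (std::size_t c = 0; c < a.size(); ++c) {
    if (!a[c].updatable) continue;  // non-updatable components are skipped
    assert(a[c].params.size() == b[c].params.size());
    float sum = 0.0f;
    for (std::size_t i = 0; i < a[c].params.size(); ++i)
      sum += a[c].params[i] * b[c].params[i];
    dots.push_back(sum);
  }
  return dots;
}

// Whole-network dot product: the sum of the per-component values
// (the role DotProduct() plays for a real Nnet).
float ToyDotProduct(const ToyNnet &a, const ToyNnet &b) {
  float total = 0.0f;
  for (float d : ToyComponentDotProducts(a, b)) total += d;
  return total;
}
```

With both arguments equal, each per-component value is the squared norm of that component's parameters.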
std::string PrintVectorPerUpdatableComponent (const Nnet &nnet, const VectorBase< BaseFloat > &vec)
This function is for printing, to a string, a vector with one element per updatable component of the nnet.
bool IsSimpleNnet (const Nnet &nnet)
This function returns true if the nnet has a list of properties, the first of which is that it has an output called "output" (other outputs are allowed but may be ignored).
void ZeroComponentStats (Nnet *nnet)
Zeroes the component stats in all nonlinear components in the nnet.
void ComputeSimpleNnetContext (const Nnet &nnet, int32 *left_context, int32 *right_context)
ComputeSimpleNnetContext computes the left-context and right-context of an nnet.
void SetLearningRate (BaseFloat learning_rate, Nnet *nnet)
Sets the underlying learning rate for all the components in the nnet to this value.
void ScaleNnet (BaseFloat scale, Nnet *nnet)
Scales the nnet parameters and stats by this scale.
void SetNnetAsGradient (Nnet *nnet)
Sets up the nnet as a gradient by setting is_gradient_ to true and learning_rate_ to 1 for each UpdatableComponent in the nnet.
void SetRequireDirectInput (bool b, Nnet *nnet)
Calls the corresponding function in any component of type StatisticsPoolingComponent; used as a way to compute the 'real' left-right context of networks including StatisticsPoolingComponent, which will give you the minimum chunk size they can consume.
void AddNnet (const Nnet &src, BaseFloat alpha, Nnet *dest)
Does *dest += alpha * src (affects nnet parameters and stored stats).
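ScaleNnet() and AddNnet() are the basic linear operations on a network's parameters. A minimal sketch, assuming a hypothetical toy network stored as one float vector per component (ToyNnet, ToyScaleNnet and ToyAddNnet are illustrative names); unlike the real functions, it does not model stored stats.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy network: one flat parameter vector per component.
// Real Nnets also carry stored stats; this sketch models parameters only.
using ToyNnet = std::vector<std::vector<float>>;

// Analogue of ScaleNnet(scale, &nnet): nnet *= scale.
void ToyScaleNnet(float scale, ToyNnet *nnet) {
  for (auto &comp : *nnet)
    for (float &p : comp) p *= scale;
}

// Analogue of AddNnet(src, alpha, &dest): dest += alpha * src.
void ToyAddNnet(const ToyNnet &src, float alpha, ToyNnet *dest) {
  assert(src.size() == dest->size());  // same structure required
  for (std::size_t c = 0; c < src.size(); ++c) {
    assert(src[c].size() == (*dest)[c].size());
    for (std::size_t i = 0; i < src[c].size(); ++i)
      (*dest)[c][i] += alpha * src[c][i];
  }
}
```

For example, calling ToyScaleNnet(0.5f, &a) followed by ToyAddNnet(b, 0.5f, &a) replaces a with the average of the two parameter sets.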
void AddNnetComponents (const Nnet &src, const Vector< BaseFloat > &alphas, BaseFloat scale, Nnet *dest)
Does *dest += alpha * src for updatable components (affects nnet parameters), with one alpha per updatable component taken from 'alphas', and *dest += scale * src for other components (affects stored stats).
bool NnetIsRecurrent (const Nnet &nnet)
Returns true if 'nnet' has some kind of recurrency.
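In a recurrent network some node depends, through a time offset, on its own earlier output, so one way to detect recurrency is to build a directed graph over the network's nodes and look for a cycle. The sketch below assumes that approach and a hypothetical adjacency-list type Graph; it shows only the cycle check (HasCycle and DfsFindsCycle are illustrative names), not how a real Nnet would be converted into a graph.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical graph form of a network: graph[i] lists the nodes that take
// node i's output as input. A recurrent connection (e.g. a node reading its
// own output at time t-1) shows up as a cycle in this graph.
using Graph = std::vector<std::vector<int>>;

enum class Mark { kUnvisited, kOnStack, kDone };

// Depth-first search from 'node'; returns true on reaching a node that is
// still on the current DFS path (a back edge, hence a cycle).
static bool DfsFindsCycle(const Graph &g, int node, std::vector<Mark> *mark) {
  (*mark)[node] = Mark::kOnStack;
  for (int next : g[node]) {
    assert(next >= 0 && static_cast<std::size_t>(next) < g.size());
    if ((*mark)[next] == Mark::kOnStack) return true;
    if ((*mark)[next] == Mark::kUnvisited && DfsFindsCycle(g, next, mark))
      return true;
  }
  (*mark)[node] = Mark::kDone;
  return false;
}

// Returns true if the directed graph has at least one cycle.
bool HasCycle(const Graph &g) {
  std::vector<Mark> mark(g.size(), Mark::kUnvisited);
  for (std::size_t i = 0; i < g.size(); ++i)
    if (mark[i] == Mark::kUnvisited &&
        DfsFindsCycle(g, static_cast<int>(i), &mark))
      return true;
  return false;
}
```

A node that is still on the DFS path when it is reached again closes a loop, which is exactly the feedback a recurrent connection introduces.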
int32 NumParameters (const Nnet &src)
Returns the total number of parameters in the updatable components of the nnet.
void VectorizeNnet (const Nnet &src, VectorBase< BaseFloat > *params)
Copies the nnet parameters to *params, whose dimension must be equal to NumParameters(src).
void UnVectorizeNnet (const VectorBase< BaseFloat > &params, Nnet *dest)
Copies the parameters from params to *dest.
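NumParameters(), VectorizeNnet() and UnVectorizeNnet() together define a flat view of the parameters. The sketch below is a hypothetical toy version (ToyNnet and the Toy* functions are not Kaldi's API) that assumes parameters are concatenated component by component; the ordering the real functions use is not stated here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy network: component c's parameters are nnet[c].
// Non-updatable components would simply have no entry here.
using ToyNnet = std::vector<std::vector<float>>;

// Analogue of NumParameters(): total parameter count over all components.
std::size_t ToyNumParameters(const ToyNnet &nnet) {
  std::size_t n = 0;
  for (const auto &comp : nnet) n += comp.size();
  return n;
}

// Analogue of VectorizeNnet(): concatenate parameters in component order.
void ToyVectorizeNnet(const ToyNnet &src, std::vector<float> *params) {
  assert(params->size() == ToyNumParameters(src));  // caller sizes the output
  std::size_t offset = 0;
  for (const auto &comp : src)
    for (float p : comp) (*params)[offset++] = p;
}

// Analogue of UnVectorizeNnet(): the inverse, using the same ordering.
void ToyUnVectorizeNnet(const std::vector<float> &params, ToyNnet *dest) {
  assert(params.size() == ToyNumParameters(*dest));
  std::size_t offset = 0;
  for (auto &comp : *dest)
    for (float &p : comp) p = params[offset++];
}
```

Because both directions walk the components in the same order, a vectorize, modify, un-vectorize round trip writes each value back to the parameter it came from.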
int32 NumUpdatableComponents (const Nnet &dest)
Returns the number of updatable components in the nnet.
void FreezeNaturalGradient (bool freeze, Nnet *nnet)
Controls whether natural gradient will be updated.
void ConvertRepeatedToBlockAffine (Nnet *nnet)
Convert all components of type RepeatedAffineComponent or NaturalGradientRepeatedAffineComponent to BlockAffineComponent in nnet.
std::string NnetInfo (const Nnet &nnet)
This function returns various info about the neural net.
void SetDropoutProportion (BaseFloat dropout_proportion, Nnet *nnet)
This function sets the dropout proportion in all dropout components to the value dropout_proportion.
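SetDropoutProportion() applies a setting only to components of one particular type. A common C++ way to express that is to loop over all components and use dynamic_cast to pick out the relevant subclass. The sketch below assumes that pattern and uses hypothetical classes (ToyComponent, ToyDropoutComponent, ToyAffineComponent) and a hypothetical ToySetDropoutProportion; it is not Kaldi's code.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical component hierarchy; the virtual destructor makes the base
// polymorphic so that dynamic_cast can identify the concrete type.
struct ToyComponent {
  virtual ~ToyComponent() = default;
};

struct ToyDropoutComponent : ToyComponent {
  float proportion = 0.5f;
};

struct ToyAffineComponent : ToyComponent {};

using ToyNnet = std::vector<std::unique_ptr<ToyComponent>>;

// Sets the proportion on every dropout component, leaves the others alone,
// and returns how many components were changed.
int ToySetDropoutProportion(float proportion, ToyNnet *nnet) {
  assert(proportion >= 0.0f && proportion <= 1.0f);
  int num_set = 0;
  for (auto &comp : *nnet) {
    if (auto *dropout = dynamic_cast<ToyDropoutComponent *>(comp.get())) {
      dropout->proportion = proportion;
      ++num_set;
    }
  }
  return num_set;
}
```

The same loop-and-cast shape would fit the other functions in this file that act only on one component type, such as SetBatchnormTestMode() or ResetGenerators().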
bool HasBatchnorm (const Nnet &nnet)
Returns true if nnet has at least one component of type BatchNormComponent.
void SetBatchnormTestMode (bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
void RecomputeStats (const std::vector< NnetExample > &egs, Nnet *nnet)
This function zeros the stored component-level stats in the nnet using ZeroComponentStats(), then recomputes them with the supplied egs.
void SetDropoutTestMode (bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
void ResetGenerators (Nnet *nnet)
This function calls 'ResetGenerator()' on all components in 'nnet' that inherit from class RandomComponent.
void FindOrphanComponents (const Nnet &nnet, std::vector< int32 > *components)
This function finds a list of components that are never used, and outputs the integer component indexes (you can use these to index nnet.GetComponentNames() to get their names).
void FindOrphanNodes (const Nnet &nnet, std::vector< int32 > *nodes)
This function finds a list of nodes that are never used to compute any output, and outputs the integer node indexes (you can use these to index nnet.GetNodeNames() to get their names).
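Finding orphan nodes is a reachability problem: a node is used if some output depends on it, directly or indirectly. A minimal sketch, assuming a hypothetical representation in which inputs[i] lists the nodes that node i reads from; ToyFindOrphanNodes walks backwards from the outputs and ignores details a real Nnet adds, such as time offsets or any special handling of input nodes.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical representation: inputs[i] lists the nodes that node i reads
// from, and is_output[i] marks output nodes. Returns, in increasing order,
// the nodes that no output depends on.
std::vector<int> ToyFindOrphanNodes(const std::vector<std::vector<int>> &inputs,
                                    const std::vector<bool> &is_output) {
  assert(inputs.size() == is_output.size());
  std::vector<bool> used(inputs.size(), false);
  std::vector<int> stack;
  for (std::size_t i = 0; i < inputs.size(); ++i) {
    if (is_output[i]) {
      used[i] = true;
      stack.push_back(static_cast<int>(i));
    }
  }
  // Walk backwards from the outputs through the dependencies.
  while (!stack.empty()) {
    int node = stack.back();
    stack.pop_back();
    for (int dep : inputs[node]) {
      if (!used[dep]) {
        used[dep] = true;
        stack.push_back(dep);
      }
    }
  }
  std::vector<int> orphans;
  for (std::size_t i = 0; i < used.size(); ++i)
    if (!used[i]) orphans.push_back(static_cast<int>(i));
  return orphans;
}
```

In this sketch a node that feeds only other orphans is itself reported, since nothing reachable from an output reads from it.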
void CollapseModel (const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable to be done at test time.
void ReadEditConfig (std::istream &config_file, Nnet *nnet)
ReadEditConfig() reads a file with a similar-looking format to the config file read by Nnet::ReadConfig(), but this consists of a sequence of operations to perform on an existing network, mostly modifying components.
bool UpdateNnetWithMaxChange (const Nnet &delta_nnet, BaseFloat max_param_change, BaseFloat max_change_scale, BaseFloat scale, Nnet *nnet, std::vector< int32 > *num_max_change_per_component_applied, int32 *num_max_change_global_applied)
This function does the operation '*nnet += scale * delta_nnet', while respecting any max-parameter-change (max-param-change) specified in the updatable components, and also the global max-param-change specified as 'max_param_change'.
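bool UpdateNnetWithMaxChange (const Nnet &delta_nnet, BaseFloat max_param_change, BaseFloat max_change_scale, BaseFloat scale, Nnet *nnet, MaxChangeStats *stats)
Overload of the above that takes a MaxChangeStats object in place of the num_max_change_per_component_applied and num_max_change_global_applied arguments.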
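The two-level limit described for UpdateNnetWithMaxChange() can be sketched on a hypothetical toy network: each component's change is first shrunk to its own max_change, and then the whole change is shrunk to max_param_change. This simplified version is an illustration, not Kaldi's implementation; it omits max_change_scale and the counters, and all names (ToyComponent, ToyNnet, ToyUpdateWithMaxChange) are invented.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical toy component: parameters plus a per-component max-change
// (a value <= 0 means "no per-component limit").
struct ToyComponent {
  std::vector<float> params;
  float max_change;
};

using ToyNnet = std::vector<ToyComponent>;

// Euclidean norm of a parameter vector, accumulated in double.
static double Norm(const std::vector<float> &v) {
  double sum = 0.0;
  for (float x : v) sum += static_cast<double>(x) * x;
  return std::sqrt(sum);
}

// Does *nnet += scale * delta, first shrinking each component's change to
// its own max_change, then shrinking the total change to max_param_change
// (a value <= 0 disables the global limit). Returns false and leaves *nnet
// unchanged if the size of the change is not finite.
bool ToyUpdateWithMaxChange(const ToyNnet &delta, float max_param_change,
                            float scale, ToyNnet *nnet) {
  assert(delta.size() == nnet->size());
  std::vector<double> factors(delta.size(), 1.0);
  double total_sq = 0.0;
  for (std::size_t c = 0; c < delta.size(); ++c) {
    double change = std::fabs(scale) * Norm(delta[c].params);
    if (delta[c].max_change > 0.0f && change > delta[c].max_change)
      factors[c] = delta[c].max_change / change;  // per-component limit
    double limited = change * factors[c];
    total_sq += limited * limited;
  }
  double total = std::sqrt(total_sq);
  if (!std::isfinite(total)) return false;
  double global_factor = 1.0;
  if (max_param_change > 0.0f && total > max_param_change)
    global_factor = max_param_change / total;  // global limit
  for (std::size_t c = 0; c < delta.size(); ++c) {
    assert(delta[c].params.size() == (*nnet)[c].params.size());
    float alpha = static_cast<float>(scale * factors[c] * global_factor);
    for (std::size_t i = 0; i < delta[c].params.size(); ++i)
      (*nnet)[c].params[i] += alpha * delta[c].params[i];
  }
  return true;
}
```

Applying the per-component limit before the global one means that a single component with an unusually large change is shrunk on its own, rather than forcing a small global scale onto every other component.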
void ApplyL2Regularization (const Nnet &nnet, BaseFloat l2_regularize_scale, Nnet *delta_nnet)
This function is used as part of the regular training workflow, prior to UpdateNnetWithMaxChange().
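The summary above does not give the formula, so the following is only a sketch of a standard L2 penalty step under stated assumptions: each hypothetical ToyComponent carries an L2 constant l2 and a learning rate eta, and the penalty l2 * ||w||^2, whose gradient is 2 * l2 * w, contributes -2 * l2_regularize_scale * l2 * eta * w to the pending change delta. The exact scaling Kaldi applies should be checked in nnet-utils.h; ToyApplyL2Regularization is an illustrative name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy component: current parameters plus the per-component
// L2 constant and learning rate assumed to scale the penalty.
struct ToyComponent {
  std::vector<float> params;
  float l2_regularize;  // 0 disables the penalty for this component
  float learning_rate;
};

// Adds the step for the penalty l2 * ||w||^2 (gradient 2 * l2 * w) to the
// pending change: delta -= 2 * l2_regularize_scale * l2 * eta * w.
void ToyApplyL2Regularization(const std::vector<ToyComponent> &nnet,
                              float l2_regularize_scale,
                              std::vector<std::vector<float>> *delta) {
  assert(nnet.size() == delta->size());
  for (std::size_t c = 0; c < nnet.size(); ++c) {
    const ToyComponent &comp = nnet[c];
    if (comp.l2_regularize == 0.0f) continue;
    assert(comp.params.size() == (*delta)[c].size());
    float coeff = 2.0f * l2_regularize_scale * comp.l2_regularize *
                  comp.learning_rate;
    for (std::size_t i = 0; i < comp.params.size(); ++i)
      (*delta)[c][i] -= coeff * comp.params[i];
  }
}
```

Folding the penalty into the pending change before it is applied means any max-change limit applied afterwards also bounds the regularization term.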
void ScaleBatchnormStats (BaseFloat batchnorm_stats_scale, Nnet *nnet)
This function scales the batchnorm stats of any batchnorm components (components of type BatchNormComponent) in 'nnet' by the scale 'batchnorm_stats_scale'.
void ConstrainOrthonormal (Nnet *nnet)
This function, to be called after processing every minibatch, is responsible for enforcing the orthogonality constraint for any components of type LinearComponent or inheriting from AffineComponent that have the "orthonormal_constraint" value set.
void ConsolidateMemory (Nnet *nnet)
This just calls ConsolidateMemory() on all the components of the nnet.
int32 GetNumNvalues (const std::vector< NnetIo > &io_vec, bool exhaustive)
This utility function can be used to obtain the number of distinct 'n' values in a training example.
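Counting distinct 'n' values amounts to collecting the n member of every index. A minimal sketch, assuming simplified stand-ins for Kaldi's Index (with members n, t and x) and NnetIo (holding a vector of indexes); ToyIndex, ToyIo and ToyGetNumNvalues are hypothetical, and the exhaustive flag of the real function is not modeled.

```cpp
#include <cassert>
#include <set>
#include <vector>

// Hypothetical stand-ins: n identifies the example within a minibatch,
// t is the time index and x an extra index; ToyIo just holds indexes.
struct ToyIndex {
  int n, t, x;
};

struct ToyIo {
  std::vector<ToyIndex> indexes;
};

// Counts the distinct 'n' values over all the NnetIo-like objects.
int ToyGetNumNvalues(const std::vector<ToyIo> &io_vec) {
  std::set<int> n_values;
  for (const ToyIo &io : io_vec)
    for (const ToyIndex &index : io.indexes) n_values.insert(index.n);
  return static_cast<int>(n_values.size());
}
```

A std::set is used here for simplicity; if the n values were known to be small non-negative integers, a boolean vector would avoid the tree overhead.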
Definition in file nnet-utils.h.