20 #ifndef KALDI_NNET3_NNET_UTILS_H_ 21 #define KALDI_NNET3_NNET_UTILS_H_ 46 const ComputationRequest &request,
47 std::vector<std::vector<bool> > *is_computable);
73 VectorBase<BaseFloat> *dot_prod);
80 const VectorBase<BaseFloat> &vec);
105 int32 *right_context);
149 VectorBase<BaseFloat> *params);
172 std::string
NnetInfo(
const Nnet &nnet);
195 void RecomputeStats(
const std::vector<NnetExample> &egs, Nnet *nnet);
246 collapse_batchnorm(false),
247 collapse_affine(true),
248 collapse_scale(true) { }
390 num_max_change_per_component_applied,
391 int32 *num_max_change_global_applied);
546 num_max_change_global_applied(0),
547 num_minibatches_processed(0),
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable to be done at test time...
void ScaleNnet(BaseFloat scale, Nnet *nnet)
Scales the nnet parameters and stats by this scale.
void EvaluateComputationRequest(const Nnet &nnet, const ComputationRequest &request, std::vector< std::vector< bool > > *is_computable)
Given an nnet and a computation request, this function works out which requested outputs in the compu...
void SetDropoutProportion(BaseFloat dropout_proportion, Nnet *nnet)
This function sets the dropout proportion in all dropout components to dropout_proportion value...
void FindOrphanComponents(const Nnet &nnet, std::vector< int32 > *components)
This function finds a list of components that are never used, and outputs the integer component index...
std::string PrintVectorPerUpdatableComponent(const Nnet &nnet, const VectorBase< BaseFloat > &vec)
This function is for printing, to a string, a vector with one element per updatable component of the ...
int32 num_max_change_global_applied
void ComponentDotProducts(const Nnet &nnet1, const Nnet &nnet2, VectorBase< BaseFloat > *dot_prod)
Returns dot products between two networks of the same structure (calls the DotProduct functions of th...
int32 num_minibatches_processed
void FindOrphanNodes(const Nnet &nnet, std::vector< int32 > *nodes)
This function finds a list of nodes that are never used to compute any output, and outputs the intege...
void ScaleBatchnormStats(BaseFloat batchnorm_stats_scale, Nnet *nnet)
This function scales the batchnorm stats of any batchnorm components (components of type BatchNormComp...
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
std::vector< int32 > num_max_change_per_component_applied
void ReadEditConfig(std::istream &edit_config_is, Nnet *nnet)
ReadEditConfig() reads a file with a similar-looking format to the config file read by Nnet::ReadConf...
void VectorizeNnet(const Nnet &src, VectorBase< BaseFloat > *parameters)
Copies the nnet parameters to *params, whose dimension must be equal to NumParameters(src).
void ConvertRepeatedToBlockAffine(CompositeComponent *c_component)
void SetNnetAsGradient(Nnet *nnet)
Sets nnet as gradient by setting is_gradient_ to true and learning_rate_ to 1 for each UpdatableCompo...
The two main classes defined in this header are struct ComputationRequest, which basically defines a ...
void ConstrainOrthonormal(Nnet *nnet)
This function, to be called after processing every minibatch, is responsible for enforcing the orthog...
void UnVectorizeNnet(const VectorBase< BaseFloat > &parameters, Nnet *dest)
Copies the parameters from params to *dest.
void FreezeNaturalGradient(bool freeze, Nnet *nnet)
Controls if natural gradient will be updated.
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
int32 GetNumNvalues(const std::vector< NnetIo > &io_vec, bool exhaustive)
This utility function can be used to obtain the number of distinct 'n' values in a training example...
void ResetGenerators(Nnet *nnet)
This function calls 'ResetGenerator()' on all components in 'nnet' that inherit from class RandomComp...
int32 NumParameters(const Nnet &src)
Returns the total of the number of parameters in the updatable components of the nnet.
void AddNnetComponents(const Nnet &src, const Vector< BaseFloat > &alphas, BaseFloat scale, Nnet *dest)
Does *dest += alpha * src for updatable components (affects nnet parameters), and *dest += scale * sr...
void ApplyL2Regularization(const Nnet &nnet, BaseFloat l2_regularize_scale, Nnet *delta_nnet)
This function is used as part of the regular training workflow, prior to UpdateNnetWithMaxChange().
void ComputeSimpleNnetContext(const Nnet &nnet, int32 *left_context, int32 *right_context)
ComputeSimpleNnetContext computes the left-context and right-context of a nnet.
void RecomputeStats(const std::vector< NnetChainExample > &egs, const chain::ChainTrainingOptions &chain_config_in, const fst::StdVectorFst &den_fst, Nnet *nnet)
This function zeros the stored component-level stats in the nnet using ZeroComponentStats(), then recomputes them with the supplied egs.
void SetLearningRate(BaseFloat learning_rate, Nnet *nnet)
Sets the underlying learning rate for all the components in the nnet to this value.
bool HasBatchnorm(const Nnet &nnet)
Returns true if nnet has at least one component of type BatchNormComponent.
void ZeroComponentStats(Nnet *nnet)
Zeroes the component stats in all nonlinear components in the nnet.
std::string NnetInfo(const Nnet &nnet)
This function returns various info about the neural net.
BaseFloat DotProduct(const Nnet &nnet1, const Nnet &nnet2)
Returns dot product between two networks of the same structure (calls the DotProduct functions of the...
bool IsSimpleNnet(const Nnet &nnet)
This function returns true if the nnet has the following properties: It has an output called "output"...
void ConsolidateMemory(Nnet *nnet)
This just calls ConsolidateMemory() on all the components of the nnet.
void SetRequireDirectInput(bool b, Nnet *nnet)
Calls the corresponding function in any component of type StatisticsPoolingComponent; used as a way t...
int32 NumInputNodes(const Nnet &nnet)
Returns the number of input nodes of this nnet.
void Print(const Fst< Arc > &fst, std::string message)
void PerturbParams(BaseFloat stddev, Nnet *nnet)
Calls PerturbParams (with the given stddev) on all updatable components of the nnet.
int32 NumOutputNodes(const Nnet &nnet)
Returns the number of output nodes of this nnet.
bool NnetIsRecurrent(const Nnet &nnet)
Returns true if 'nnet' has some kind of recurrency.
This file contains class definitions for classes ForwardingDescriptor, SumDescriptor and Descriptor...
bool UpdateNnetWithMaxChange(const Nnet &delta_nnet, BaseFloat max_param_change, BaseFloat max_change_scale, BaseFloat scale, Nnet *nnet, std::vector< int32 > *num_max_change_per_component_applied, int32 *num_max_change_global_applied)
This function does the operation '*nnet += scale * delta_nnet', while respecting any max-parameter-ch...
int32 NumUpdatableComponents(const Nnet &dest)
Returns the number of updatable components in the nnet.
void AddNnet(const Nnet &src, BaseFloat alpha, Nnet *dest)
Does *dest += alpha * src (affects nnet parameters and stored stats).
MaxChangeStats(const Nnet &nnet)
Config class for the CollapseModel function.