#ifndef KALDI_NNET3_NNET_TRAINING_H_
#define KALDI_NNET3_NNET_TRAINING_H_

// ... (includes and namespace openers elided) ...

struct NnetTrainerOptions {
  // ... (member declarations elided; the fields are initialized below) ...

  NnetTrainerOptions():
      zero_component_stats(true),
      store_component_stats(true),
      print_interval(100),
      debug_computation(false),
      momentum(0.0),
      l2_regularize_factor(1.0),
      backstitch_training_scale(0.0),
      backstitch_training_interval(1),
      batchnorm_stats_scale(0.8),
      binary_write_cache(true),
      max_param_change(2.0) { }
  void Register(OptionsItf *opts) {
    opts->Register("store-component-stats", &store_component_stats,
                   "If true, store activations and derivatives for nonlinear "
                   "components during training.");
    opts->Register("zero-component-stats", &zero_component_stats,
                   "If both this and --store-component-stats are true, then "
                   "the component stats are zeroed before training.");
    opts->Register("print-interval", &print_interval, "Interval (measured in "
                   "minibatches) after which we print out objective function "
                   "during training.");
    opts->Register("max-param-change", &max_param_change, "The maximum change in "
                   "parameters allowed per minibatch, measured in Euclidean norm "
                   "over the entire model (change will be clipped to this value)");
    opts->Register("momentum", &momentum, "Momentum constant to apply during "
                   "training (helps stabilize updates), e.g. 0.9.  Note: we "
                   "automatically multiply the learning rate by (1-momentum) "
                   "so that the 'effective' learning rate is the same as "
                   "before (because momentum would normally increase the "
                   "effective learning rate by 1/(1-momentum)).");
    opts->Register("l2-regularize-factor", &l2_regularize_factor, "Factor that "
                   "affects the strength of l2 regularization on model "
                   "parameters.  The primary way to specify this type of "
                   "l2 regularization is via the 'l2-regularize' "
                   "configuration value at the config-file level. "
                   "--l2-regularize-factor will be multiplied by the component-level "
                   "l2-regularize values and can be used to correct for effects "
                   "related to parallelization by model averaging.");
    opts->Register("batchnorm-stats-scale", &batchnorm_stats_scale,
                   "Factor by which we scale down the accumulated stats of batchnorm "
                   "layers after processing each minibatch.  Ensures that the final "
                   "model we write out has batchnorm stats that are fairly fresh.");
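    // Illustrative arithmetic: with --batchnorm-stats-scale=0.8, stats from a
    // minibatch k steps in the past carry weight 0.8^k, i.e. an effective
    // averaging window of about 1 / (1 - 0.8) = 5 recent minibatches.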
    opts->Register("backstitch-training-scale", &backstitch_training_scale,
                   "Backstitch training factor. "
                   "If 0 then we are in normal training mode. It is referred "
                   "to as '\\alpha' in our publications.");
    opts->Register("backstitch-training-interval",
                   &backstitch_training_interval,
                   "Do backstitch training with the specified interval of "
                   "minibatches. It is referred to as 'n' in our publications.");
    opts->Register("read-cache", &read_cache, "The location from which to read "
                   "the cached computation.");
    opts->Register("write-cache", &write_cache, "The location to which to write "
                   "the cached computation.");
    opts->Register("binary-write-cache", &binary_write_cache, "Write "
                   "computation cache in binary mode.");
    // Register the optimization options with the prefix "optimization".
    ParseOptions optimization_opts("optimization", opts);
    optimize_config.Register(&optimization_opts);
    // Register the compiler options with the prefix "compiler".
    ParseOptions compiler_opts("compiler", opts);
    compiler_config.Register(&compiler_opts);
    // Register the compute options with the prefix "computation".
    ParseOptions compute_opts("computation", opts);
    compute_config.Register(&compute_opts);
  }
};
struct ObjectiveFunctionInfo {
  // ... (member declarations elided; the fields are initialized below) ...

  ObjectiveFunctionInfo():
      current_phase(0),
      minibatches_this_phase(0),
      tot_weight(0.0), tot_objf(0.0), tot_aux_objf(0.0),
      tot_weight_this_phase(0.0), tot_objf_this_phase(0.0),
      tot_aux_objf_this_phase(0.0) { }
  // Updates the stats and, if the phase has just changed, prints a progress
  // message for the phase that just ended.
  void UpdateStats(const std::string &output_name,
                   int32 minibatches_per_phase,
                   int32 minibatch_counter,
                   BaseFloat this_minibatch_weight,
                   BaseFloat this_minibatch_tot_objf,
                   BaseFloat this_minibatch_tot_aux_objf = 0.0);
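
  // Usage sketch (illustrative): a trainer calls this once per minibatch for
  // each output node, along the lines of
  //   info.UpdateStats(output_name, config.print_interval,
  //                    num_minibatches_processed, tot_weight, tot_objf);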
  void PrintStatsForThisPhase(const std::string &output_name,
                              int32 minibatches_per_phase,
                              int32 phase) const;
  bool PrintTotalStats(const std::string &output_name) const;
};
// ... (beginning of class NnetTrainer elided) ...

  // Prints out the final stats; returns true if the count was nonzero.
  bool PrintTotalStats() const;
 private:
  // ... (other private member functions elided) ...

  void TrainInternalBackstitch(const NnetExample &eg,
                               const NnetComputation &computation,
                               bool is_backstitch_step1);
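
  // Backstitch sketch (inferred from the --backstitch-training-* doc strings
  // above; the exact update lives in the corresponding .cc file): every
  // backstitch_training_interval minibatches, two updates are applied to the
  // same minibatch instead of one:
  //   step 1 (is_backstitch_step1 == true):  params += -alpha * delta
  //   step 2 (is_backstitch_step1 == false): params += (1 + alpha) * delta
  // where alpha == backstitch_training_scale and delta is the ordinary
  // gradient step; on all other minibatches a single normal update is taken.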
  void ProcessOutputs(bool is_backstitch_step2, const NnetExample &eg,
                      NnetComputer *computer);
  // ... (other private members, e.g. config_, nnet_, compiler_,
  //      num_minibatches_processed_, elided) ...

  // Objective-function stats, indexed by the name of the output node
  // (normally just "output").
  unordered_map<std::string, ObjectiveFunctionInfo, StringHasher> objf_info_;
};

// This function computes the objective function and, if supply_deriv == true,
// supplies its derivative to the NnetComputer object.
void ComputeObjectiveFunction(const GeneralMatrix &supervision,
                              ObjectiveType objective_type,
                              const std::string &output_name,
                              bool supply_deriv,
                              NnetComputer *computer,
                              BaseFloat *tot_weight,
                              BaseFloat *tot_objf);
#endif // KALDI_NNET3_NNET_TRAINING_H_
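
// Usage sketch (illustrative; it mirrors the shape of the nnet3-train binary,
// but variable names and rspecifiers here are assumptions for the example):
//
//   NnetTrainerOptions train_config;  // registered on the command line as above
//   Nnet nnet;
//   ReadKaldiObject(nnet_rxfilename, &nnet);
//   NnetTrainer trainer(train_config, &nnet);
//   SequentialNnetExampleReader example_reader(examples_rspecifier);
//   for (; !example_reader.Done(); example_reader.Next())
//     trainer.Train(example_reader.Value());
//   bool ok = trainer.PrintTotalStats();  // also prints per-output objf info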