#ifndef KALDI_NNET3_NNET_TRAINING_H_
#define KALDI_NNET3_NNET_TRAINING_H_

struct NnetTrainerOptions {
  bool zero_component_stats;
  bool store_component_stats;
  int32 print_interval;
  bool debug_computation;
  BaseFloat momentum;
  std::string read_cache;
  std::string write_cache;
  bool binary_write_cache;
  BaseFloat l2_regularize_factor;
  BaseFloat backstitch_training_scale;
  int32 backstitch_training_interval;
  BaseFloat batchnorm_stats_scale;
  BaseFloat max_param_change;
  NnetOptimizeOptions optimize_config;
  NnetComputeOptions compute_config;
  CachingOptimizingCompilerOptions compiler_config;

  NnetTrainerOptions():
      zero_component_stats(true),
      store_component_stats(true),
      print_interval(100),
      debug_computation(false),
      momentum(0.0),
      l2_regularize_factor(1.0),
      backstitch_training_scale(0.0),
      backstitch_training_interval(1),
      batchnorm_stats_scale(0.8),
      binary_write_cache(true),
      max_param_change(2.0) { }
  void Register(OptionsItf *opts) {
    opts->Register("store-component-stats", &store_component_stats,
                   "If true, store activations and derivatives for nonlinear "
                   "components during training.");
    opts->Register("zero-component-stats", &zero_component_stats,
                   "If both this and --store-component-stats are true, then "
                   "the component stats are zeroed before training.");
    opts->Register("print-interval", &print_interval,
                   "Interval (measured in minibatches) after which we print "
                   "out objective function during training.");
    opts->Register("max-param-change", &max_param_change,
                   "The maximum change in parameters allowed per minibatch, "
                   "measured in Euclidean norm over the entire model (change "
                   "will be clipped to this value).");
    opts->Register("momentum", &momentum,
                   "Momentum constant to apply during training (helps to "
                   "stabilize the update), e.g. 0.9.  Note: we automatically "
                   "multiply the learning rate by (1-momentum) so that the "
                   "'effective' learning rate is the same as before (because "
                   "momentum would normally increase the effective learning "
                   "rate by 1/(1-momentum)).");
    opts->Register("l2-regularize-factor", &l2_regularize_factor,
                   "Factor that affects the strength of l2 regularization on "
                   "model parameters.  The primary way to specify this type "
                   "of l2 regularization is via the 'l2-regularize' "
                   "configuration value at the config-file level.  "
                   "--l2-regularize-factor will be multiplied by the "
                   "component-level l2-regularize values, and can be used to "
                   "correct for effects related to parallelization by model "
                   "averaging.");
    opts->Register("batchnorm-stats-scale", &batchnorm_stats_scale,
                   "Factor by which we scale down the accumulated stats of "
                   "batchnorm layers after processing each minibatch.  This "
                   "ensures that the final model we write out has batchnorm "
                   "stats that are fairly fresh.");
    opts->Register("backstitch-training-scale", &backstitch_training_scale,
                   "Backstitch training factor; if 0 we are in normal "
                   "training mode.  It is referred to as '\\alpha' in our "
                   "publications.");
    opts->Register("backstitch-training-interval",
                   &backstitch_training_interval,
                   "Do backstitch training with the specified interval of "
                   "minibatches.  It is referred to as 'n' in our "
                   "publications.");
    opts->Register("read-cache", &read_cache,
                   "The location from which to read the cached computation.");
    opts->Register("write-cache", &write_cache,
                   "The location to which to write the cached computation.");
    opts->Register("binary-write-cache", &binary_write_cache,
                   "Write computation cache in binary mode.");
    // Register the optimization, compiler and compute options with prefixes,
    // so they appear on the command line as e.g. --optimization.xxx,
    // --compiler.xxx and --computation.xxx.
    ParseOptions optimization_opts("optimization", opts);
    optimize_config.Register(&optimization_opts);
    ParseOptions compiler_opts("compiler", opts);
    compiler_config.Register(&compiler_opts);
    ParseOptions compute_opts("computation", opts);
    compute_config.Register(&compute_opts);
  }
};
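// A minimal usage sketch (illustrative, not part of the original header),
// showing how NnetTrainerOptions is typically wired into a training binary.
// ParseOptions and Register() are real Kaldi APIs; this main() is a
// simplified stand-in for a program like nnet3-train:
//
//   int main(int argc, char *argv[]) {
//     using namespace kaldi;
//     using namespace kaldi::nnet3;
//     ParseOptions po("Usage: my-train [options] ...");
//     NnetTrainerOptions train_config;
//     train_config.Register(&po);  // adds --momentum, --max-param-change, ...
//     po.Read(argc, argv);         // e.g. --momentum=0.9
//                                  //      --optimization.min-deriv-time=0
//     NnetTrainer trainer(train_config, &nnet);  // nnet loaded elsewhere
//     // for each training example 'eg': trainer.Train(eg);
//   }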
// Keeps track of objective-function values during training, for one output.
struct ObjectiveFunctionInfo {
  int32 current_phase;
  int32 minibatches_this_phase;
  double tot_weight;
  double tot_objf;
  double tot_aux_objf;
  double tot_weight_this_phase;
  double tot_objf_this_phase;
  double tot_aux_objf_this_phase;

  ObjectiveFunctionInfo():
      current_phase(0),
      minibatches_this_phase(0),
      tot_weight(0.0), tot_objf(0.0), tot_aux_objf(0.0),
      tot_weight_this_phase(0.0), tot_objf_this_phase(0.0),
      tot_aux_objf_this_phase(0.0) { }
  void UpdateStats(const std::string &output_name,
                   int32 minibatches_per_phase,
                   int32 minibatch_counter,
                   BaseFloat this_minibatch_weight,
                   BaseFloat this_minibatch_tot_objf,
                   BaseFloat this_minibatch_tot_aux_objf = 0.0);
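  // Behavior summary (our gloss on the declaration above, not the original
  // comment): UpdateStats computes
  //   phase = minibatch_counter / minibatches_per_phase;
  // when the phase advances, it prints the stats accumulated for the
  // previous phase (via PrintStatsForThisPhase) and resets the *_this_phase
  // accumulators, then adds this minibatch's weight, objf and aux-objf into
  // the running totals.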
  void PrintStatsForThisPhase(const std::string &output_name,
                              int32 minibatches_per_phase,
                              int32 phase) const;
  bool PrintTotalStats(const std::string &output_name) const;
};
// This class is for single-threaded training of neural nets using standard
// objective functions such as cross-entropy and quadratic loss.
class NnetTrainer {
 public:
  // ...
  bool PrintTotalStats() const;

 private:
  // ...
  void TrainInternalBackstitch(const NnetExample &eg,
                               const NnetComputation &computation,
                               bool is_backstitch_step1);

  void ProcessOutputs(bool is_backstitch_step2, const NnetExample &eg,
                      NnetComputer *computer);

  // Maps from the names of the output nodes (e.g. "output") to the
  // objective-function stats we accumulate for them.
  unordered_map<std::string, ObjectiveFunctionInfo, StringHasher> objf_info_;
};
// This function computes the objective function and, if supply_deriv == true,
// supplies its derivative to the NnetComputer object.
void ComputeObjectiveFunction(const GeneralMatrix &supervision,
                              ObjectiveType objective_type,
                              const std::string &output_name,
                              bool supply_deriv,
                              NnetComputer *computer,
                              BaseFloat *tot_weight,
                              BaseFloat *tot_objf);
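// Sketch of the linear case (based on this function's documentation, not its
// verbatim body): for ObjectiveType kLinear the objective is the inner
// product of the supervision and the network output, roughly
//
//   *tot_weight = supervision.Sum();
//   *tot_objf = TraceMatMat(output, supervision, kTrans);
//
// and if supply_deriv is true, the derivative handed to the NnetComputer is
// simply the supervision matrix itself.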
#endif  // KALDI_NNET3_NNET_TRAINING_H_