NnetTrainer::NnetTrainer(const NnetTrainerOptions &config,
                         Nnet *nnet):
    config_(config),
    nnet_(nnet),
    compiler_(*nnet, config_.optimize_config, config_.compiler_config),
    num_minibatches_processed_(0),
    max_change_stats_(*nnet),
    srand_seed_(RandInt(0, 100000)) {
  // ... (component-stats zeroing and delta_nnet_ setup elided; then, if
  // the computation cache named in the config cannot be opened for reading):
      KALDI_WARN << "Could not open cached computation. "
                    "Probably this is the first training iteration.";
  // ...
}
void NnetTrainer::Train(const NnetExample &eg) {
  bool need_model_derivative = true;
  ComputationRequest request;
  GetComputationRequest(*nnet_, eg, need_model_derivative,
                        config_.store_component_stats, &request);
  std::shared_ptr<const NnetComputation> computation =
      compiler_.Compile(request);

  if (config_.backstitch_training_scale > 0.0 &&
      num_minibatches_processed_ % config_.backstitch_training_interval ==
      srand_seed_ % config_.backstitch_training_interval) {
    bool is_backstitch_step1 = true;
    TrainInternalBackstitch(eg, *computation, is_backstitch_step1);
    // ... (un-freeze natural gradient, re-seed the RNG generators) ...
    is_backstitch_step1 = false;
    TrainInternalBackstitch(eg, *computation, is_backstitch_step1);
  } else {  // conventional training
    TrainInternal(eg, *computation);
  }
  num_minibatches_processed_++;
}
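The two TrainInternalBackstitch() calls implement the backstitch update of Wang et al. (Interspeech 2017). Writing \alpha for config_.backstitch_training_scale and \Delta(\theta) for the (max-change-limited, learning-rate-scaled) parameter delta computed on this minibatch at parameters \theta, the two sub-steps amount to (a summary of the code, not an excerpt):

\theta' = \theta - \alpha\,\Delta(\theta) \quad \text{(step 1)}, \qquad
\theta_{\mathrm{new}} = \theta' + (1+\alpha)\,\Delta(\theta') \quad \text{(step 2)},

i.e. a small negative step followed by an enlarged positive step from the displaced point; this matches the scale_adding values of -\alpha and 1 + \alpha in TrainInternalBackstitch() below.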
void NnetTrainer::TrainInternalBackstitch(const NnetExample &eg,
                                          const NnetComputation &computation,
                                          bool is_backstitch_step1) {
  // ... (forward/backward computation, as in TrainInternal) ...
  bool is_backstitch_step2 = !is_backstitch_step1;
  // ...
  BaseFloat max_change_scale, scale_adding;
  if (is_backstitch_step1) {
    max_change_scale = config_.backstitch_training_scale;
    scale_adding = -config_.backstitch_training_scale;  // negative step
  } else { /* both get 1.0 + config_.backstitch_training_scale */ }
  UpdateNnetWithMaxChange(*delta_nnet_, config_.max_param_change,
                          max_change_scale, scale_adding, nnet_,
                          &max_change_stats_);
  if (is_backstitch_step1) {
    ConstrainOrthonormal(nnet_);   // only on step 1, for efficiency
  }
  if (!is_backstitch_step1) {
    // scale down batchnorm stats after step 2, before the next minibatch
    ScaleBatchnormStats(config_.batchnorm_stats_scale, nnet_);
  }
}
void NnetTrainer::ProcessOutputs(bool is_backstitch_step2,
                                 const NnetExample &eg,
                                 NnetComputer *computer) {
  // In backstitch training, the output name with the "_backstitch" suffix
  // is the one computed after the first (backward) step of backstitch.
  const std::string suffix = (is_backstitch_step2 ? "_backstitch" : "");
  std::vector<NnetIo>::const_iterator iter = eg.io.begin(),
      end = eg.io.end();
  for (; iter != end; ++iter) {
    const NnetIo &io = *iter;
    // ... (io's that are not output nodes are skipped; obj_type is the
    //      output node's ObjectiveType) ...
    BaseFloat tot_weight, tot_objf;
    bool supply_deriv = true;
    ComputeObjectiveFunction(io.features, obj_type, io.name,
                             supply_deriv, computer,
                             &tot_weight, &tot_objf);
    objf_info_[io.name + suffix].UpdateStats(io.name + suffix,
                                             config_.print_interval,
                                             num_minibatches_processed_,
                                             tot_weight, tot_objf);
  }
}
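Stepping back, here is a minimal sketch of how a training binary drives NnetTrainer, modeled on Kaldi's nnet3-train; the rspecifier/rxfilename variables are placeholders:

Nnet nnet;
ReadKaldiObject(nnet_rxfilename, &nnet);   // load the raw nnet

NnetTrainer trainer(train_config, &nnet);
SequentialNnetExampleReader example_reader(examples_rspecifier);
for (; !example_reader.Done(); example_reader.Next())
  trainer.Train(example_reader.Value());   // one call per minibatch-example

bool ok = trainer.PrintTotalStats();       // log overall objective values
WriteKaldiObject(nnet, nnet_wxfilename, binary_write);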
bool NnetTrainer::PrintTotalStats() const {
  unordered_map<std::string, ObjectiveFunctionInfo, StringHasher>::const_iterator
      iter = objf_info_.begin(),
      end = objf_info_.end();
  std::vector<std::pair<std::string, const ObjectiveFunctionInfo*> > all_pairs;
  for (; iter != end; ++iter)
    all_pairs.push_back(std::pair<std::string, const ObjectiveFunctionInfo*>(
        iter->first, &(iter->second)));
  // ensure a deterministic order of the names (this matters when a script
  // greps the objective values out of the log).
  std::sort(all_pairs.begin(), all_pairs.end());
  bool ans = false;
  for (size_t i = 0; i < all_pairs.size(); i++) {
    const std::string &name = all_pairs[i].first;
    const ObjectiveFunctionInfo &info = *(all_pairs[i].second);
    ans = info.PrintTotalStats(name) || ans;
  }
  max_change_stats_.Print(*nnet_);
  return ans;
}
void ObjectiveFunctionInfo::UpdateStats(
    const std::string &output_name,
    int32 minibatches_per_phase,
    int32 minibatch_counter,
    BaseFloat this_minibatch_weight,
    BaseFloat this_minibatch_tot_objf,
    BaseFloat this_minibatch_tot_aux_objf) {
  int32 phase = minibatch_counter / minibatches_per_phase;
  if (phase != current_phase) {
    // a new phase has begun: print and reset the per-phase stats.
    PrintStatsForThisPhase(output_name, minibatches_per_phase, phase);
    current_phase = phase;
    tot_weight_this_phase = 0.0;
    tot_objf_this_phase = 0.0;
    tot_aux_objf_this_phase = 0.0;
    minibatches_this_phase = 0;
  }
  minibatches_this_phase++;
  tot_weight_this_phase += this_minibatch_weight;
  tot_objf_this_phase += this_minibatch_tot_objf;
  tot_aux_objf_this_phase += this_minibatch_tot_aux_objf;
  tot_weight += this_minibatch_weight;
  tot_objf += this_minibatch_tot_objf;
  tot_aux_objf += this_minibatch_tot_aux_objf;
}
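To make the phase arithmetic concrete: with minibatches_per_phase = 100 (the trainer's print interval), minibatch counters 0-99 fall in phase 0 and 100-199 in phase 1; the first UpdateStats() call of a new phase prints the finished phase's averages and zeroes the *_this_phase accumulators, while tot_weight, tot_objf and tot_aux_objf keep accumulating over the whole run for PrintTotalStats().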
void ObjectiveFunctionInfo::PrintStatsForThisPhase(
    const std::string &output_name,
    int32 minibatches_per_phase, int32 phase) const {
  int32 start_minibatch = current_phase * minibatches_per_phase,
      end_minibatch = phase * minibatches_per_phase - 1;

  if (tot_aux_objf_this_phase == 0.0) {
    if (minibatches_per_phase == minibatches_this_phase) {
      KALDI_LOG << "Average objective function for '" << output_name
                << "' for minibatches " << start_minibatch
                << '-' << end_minibatch << " is "
                << (tot_objf_this_phase / tot_weight_this_phase)
                << " over " << tot_weight_this_phase << " frames.";
    } else {
      KALDI_LOG << "Average objective function for '" << output_name
                << "' using " << minibatches_this_phase
                << " minibatches in minibatch range " << start_minibatch
                << '-' << end_minibatch << " is "
                << (tot_objf_this_phase / tot_weight_this_phase)
                << " over " << tot_weight_this_phase << " frames.";
    }
  } else {
    BaseFloat objf = (tot_objf_this_phase / tot_weight_this_phase),
        aux_objf = (tot_aux_objf_this_phase / tot_weight_this_phase),
        sum_objf = objf + aux_objf;
    if (minibatches_per_phase == minibatches_this_phase) {
      KALDI_LOG << "Average objective function for '" << output_name
                << "' for minibatches " << start_minibatch
                << '-' << end_minibatch << " is "
                << objf << " + " << aux_objf << " = " << sum_objf
                << " over " << tot_weight_this_phase << " frames.";
    } else {
      KALDI_LOG << "Average objective function for '" << output_name
                << "' using " << minibatches_this_phase
                << " minibatches in minibatch range " << start_minibatch
                << '-' << end_minibatch << " is "
                << objf << " + " << aux_objf << " = " << sum_objf
                << " over " << tot_weight_this_phase << " frames.";
    }
  }
}
bool ObjectiveFunctionInfo::PrintTotalStats(const std::string &name) const {
  BaseFloat objf = (tot_objf / tot_weight),
      aux_objf = (tot_aux_objf / tot_weight),
      sum_objf = objf + aux_objf;
  if (tot_aux_objf == 0.0) {
    KALDI_LOG << "Overall average objective function for '" << name
              << "' is " << (tot_objf / tot_weight)
              << " over " << tot_weight << " frames.";
  } else {
    KALDI_LOG << "Overall average objective function for '" << name
              << "' is " << objf << " + " << aux_objf << " = " << sum_objf
              << " over " << tot_weight << " frames.";
  }
  KALDI_LOG << "[this line is to be parsed by a script:] "
            << "log-prob-per-frame=" << objf;
  return (tot_weight != 0.0);
}
void ComputeObjectiveFunction(const GeneralMatrix &supervision,
                              ObjectiveType objective_type,
                              const std::string &output_name,
                              bool supply_deriv,
                              NnetComputer *computer,
                              BaseFloat *tot_weight,
                              BaseFloat *tot_objf) {
  const CuMatrixBase<BaseFloat> &output = computer->GetOutput(output_name);
  if (output.NumCols() != supervision.NumCols())
    KALDI_ERR << "Nnet versus example output dimension (num-classes) "
              << "mismatch for '" << output_name << "': " << output.NumCols()
              << " (nnet) vs. " << supervision.NumCols() << " (egs)\n";

  switch (objective_type) {
    case kLinear: {
      // objective is x * y.
      switch (supervision.Type()) {
        case kSparseMatrix: {
          // ... (copy the sparse supervision to the GPU as cu_post) ...
          *tot_weight = cu_post.Sum();
          // ... (objf via TraceMatSmat; derivative is cu_post itself) ...
          break;
        }
        case kFullMatrix: {
          // ... (copy the full supervision matrix to cu_post) ...
          *tot_weight = cu_post.Sum();
          // ... (objf via TraceMatMat) ...
          break;
        }
        case kCompressedMatrix: {
          // ... (uncompress, then swap into cu_post) ...
          *tot_weight = cu_post.Sum();
          // ... (objf via TraceMatMat) ...
          break;
        }
      }
      break;
    }
    case kQuadratic: {
      // objective is -0.5 (x - y)^2
      // ... (diff is initialized to a copy of the supervision y) ...
      diff.AddMat(-1.0, output);
      *tot_weight = diff.NumRows();
      // ... (objf via TraceMatMat(diff, diff, kTrans); derivative is diff) ...
      break;
    }
    default:
      KALDI_ERR << "Objective function type " << objective_type
                << " not handled.";
  }
}