NnetChainTrainer::NnetChainTrainer(const NnetChainTrainingOptions &opts,
                                   const fst::StdVectorFst &den_fst,
                                   Nnet *nnet):
    opts_(opts),
    den_graph_(den_fst, nnet->OutputDim("output")),
    nnet_(nnet),
    compiler_(*nnet, opts_.nnet_config.optimize_config,
              opts_.nnet_config.compiler_config),
    num_minibatches_processed_(0),
    max_change_stats_(*nnet),
    srand_seed_(RandInt(0, 100000)) {
  // ... (option checks and delta-nnet setup elided) ...
  if (opts.nnet_config.read_cache != "") {
    try {
      // ... (open opts.nnet_config.read_cache, then compiler_.ReadCache()) ...
    } catch (...) {
      KALDI_WARN << "Could not open cached computation. "
                    "Probably this is the first training iteration.";
    }
  }
}
void NnetChainTrainer::Train(const NnetChainExample &eg) {
  bool need_model_derivative = true;
  bool use_xent_regularization = (opts_.chain_config.xent_regularize != 0.0);
  ComputationRequest request;
  GetChainComputationRequest(*nnet_, eg, need_model_derivative,
                             opts_.nnet_config.store_component_stats,
                             use_xent_regularization, need_model_derivative,
                             &request);
  std::shared_ptr<const NnetComputation> computation =
      compiler_.Compile(request);
  if (opts_.nnet_config.backstitch_training_scale > 0.0 &&
      num_minibatches_processed_ % opts_.nnet_config.backstitch_training_interval
          == srand_seed_ % opts_.nnet_config.backstitch_training_interval) {
    // Two passes over this minibatch: a negative then a positive step.
    bool is_backstitch_step1 = true;
    TrainInternalBackstitch(eg, *computation, is_backstitch_step1);
    is_backstitch_step1 = false;
    TrainInternalBackstitch(eg, *computation, is_backstitch_step1);
  } else {  // conventional training
    TrainInternal(eg, *computation);
  }
  num_minibatches_processed_++;
}
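For orientation, the two backstitch passes above implement the update sketched below (notation is mine, not from the source: $\alpha$ is backstitch_training_scale and $\Delta(\theta)$ is the parameter step this minibatch yields at parameters $\theta$):

\theta' = \theta - \alpha\,\Delta(\theta), \qquad
\theta_{\mathrm{new}} = \theta' + (1 + \alpha)\,\Delta(\theta')

With $\alpha = 0$ this reduces to the conventional single update $\theta + \Delta(\theta)$, matching the else-branch.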
void NnetChainTrainer::TrainInternalBackstitch(const NnetChainExample &eg,
                                               const NnetComputation &computation,
                                               bool is_backstitch_step1) {
  NnetComputer computer(opts_.nnet_config.compute_config, computation,
                        nnet_, delta_nnet_);
  computer.AcceptInputs(*nnet_, eg.inputs);
  computer.Run();  // forward pass
  bool is_backstitch_step2 = !is_backstitch_step1;
  this->ProcessOutputs(is_backstitch_step2, eg, &computer);
  computer.Run();  // backward pass
  BaseFloat max_change_scale, scale_adding;
  if (is_backstitch_step1) {
    // Step 1: the delta is subtracted, scaled by backstitch_training_scale.
    max_change_scale = opts_.nnet_config.backstitch_training_scale;
    scale_adding = -opts_.nnet_config.backstitch_training_scale;
  } else {
    // Step 2: the delta is added, scaled by 1 + backstitch_training_scale.
    max_change_scale = 1.0 + opts_.nnet_config.backstitch_training_scale;
    scale_adding = 1.0 + opts_.nnet_config.backstitch_training_scale;
  }
  // ... (L2 regularization gradient added to delta_nnet_; see below) ...
  UpdateNnetWithMaxChange(*delta_nnet_, opts_.nnet_config.max_param_change,
                          max_change_scale, scale_adding, nnet_,
                          &max_change_stats_);
  if (is_backstitch_step1)
    ConstrainOrthonormal(nnet_);  // only on the first step, for efficiency
  if (!is_backstitch_step1)
    ScaleBatchnormStats(opts_.nnet_config.batchnorm_stats_scale, nnet_);
  ScaleNnet(0.0, delta_nnet_);  // zero the delta for the next minibatch
}
void NnetChainTrainer::ProcessOutputs(bool is_backstitch_step2,
                                      const NnetChainExample &eg,
                                      NnetComputer *computer) {
  // In backstitch training, stats from the second step get a suffix.
  const std::string suffix = (is_backstitch_step2 ? "_backstitch" : "");
  std::vector<NnetChainSupervision>::const_iterator iter = eg.outputs.begin(),
      end = eg.outputs.end();
  for (; iter != end; ++iter) {
    const NnetChainSupervision &sup = *iter;
    int32 node_index = nnet_->GetNodeIndex(sup.name);
    if (node_index < 0 || !nnet_->IsOutputNode(node_index))
      KALDI_ERR << "Network has no output named " << sup.name;
    const CuMatrixBase<BaseFloat> &nnet_output = computer->GetOutput(sup.name);
    CuMatrix<BaseFloat> nnet_output_deriv(nnet_output.NumRows(),
                                          nnet_output.NumCols(), kUndefined);
    bool use_xent = (opts_.chain_config.xent_regularize != 0.0);
    std::string xent_name = sup.name + "-xent";  // typically "output-xent"
    CuMatrix<BaseFloat> xent_deriv;
    BaseFloat tot_objf, tot_l2_term, tot_weight;
    ComputeChainObjfAndDeriv(opts_.chain_config, den_graph_, sup.supervision,
                             nnet_output, &tot_objf, &tot_l2_term, &tot_weight,
                             &nnet_output_deriv,
                             (use_xent ? &xent_deriv : NULL));
    if (use_xent) {
      // Cross-entropy objf: elementwise dot of log-probs with posteriors.
      BaseFloat xent_objf = TraceMatMat(computer->GetOutput(xent_name),
                                        xent_deriv, kTrans);
      objf_info_[xent_name + suffix].UpdateStats(xent_name + suffix,
          opts_.nnet_config.print_interval, num_minibatches_processed_,
          tot_weight, xent_objf);
    }
    if (opts_.apply_deriv_weights && sup.deriv_weights.Dim() != 0) {
      CuVector<BaseFloat> cu_deriv_weights(sup.deriv_weights);
      nnet_output_deriv.MulRowsVec(cu_deriv_weights);
    }
    computer->AcceptInput(sup.name, &nnet_output_deriv);
    objf_info_[sup.name + suffix].UpdateStats(sup.name + suffix,
        opts_.nnet_config.print_interval, num_minibatches_processed_,
        tot_weight, tot_objf, tot_l2_term);
    // ... (xent_deriv scaled and fed back if used) ...
  }
}
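A note on the cross-entropy branch above: TraceMatMat(A, B, kTrans) computes $\operatorname{tr}(AB^{\top}) = \sum_{i,j} A_{ij}B_{ij}$, so assuming the xent output $X$ holds log-probabilities and xent_deriv holds numerator posteriors $\Gamma$ at that point, the accumulated objective is

\mathrm{xent\_objf} \;=\; \operatorname{tr}\bigl(X\Gamma^{\top}\bigr)
\;=\; \sum_{t,i} \gamma_{t,i}\,\log p_{t,i}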
bool NnetChainTrainer::PrintTotalStats() const {
  unordered_map<std::string, ObjectiveFunctionInfo, StringHasher>::const_iterator
      iter = objf_info_.begin(), end = objf_info_.end();
  for (; iter != end; ++iter) {
    const std::string &name = iter->first;
    // ... (prints accumulated objective stats for each output) ...
  }
}
void TrainInternal(const NnetChainExample &eg, const NnetComputation &computation)
void ScaleNnet(BaseFloat scale, Nnet *nnet)
Scales the nnet parameters and stats by this scale.
Vector< BaseFloat > deriv_weights
This is a vector of per-frame weights, required to be between 0 and 1, that is applied to the derivat...
void TrainInternalBackstitch(const NnetChainExample &eg, const NnetComputation &computation, bool is_backstitch_step1)
chain::Supervision supervision
The supervision object, containing the FST.
std::vector< NnetIo > inputs
'inputs' contains the input to the network -- normally it has just one element called "input"...
void ScaleBatchnormStats(BaseFloat batchnorm_stats_scale, Nnet *nnet)
This function scales the batchnorm stats of any batchnorm components (components of type BatchNormComp...
const NnetChainTrainingOptions opts_
BaseFloat l2_regularize_factor
CuMatrix< BaseFloat >
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
fst::StdVectorFst StdVectorFst
chain::DenominatorGraph den_graph_
std::string name
the name of the output in the neural net; in simple setups it will just be "output".
int32 backstitch_training_interval
This file contains some miscellaneous functions dealing with class Nnet.
void ConstrainOrthonormal(Nnet *nnet)
This function, to be called after processing every minibatch, is responsible for enforcing the orthog...
void FreezeNaturalGradient(bool freeze, Nnet *nnet)
Controls if natural gradient will be updated.
int32 GetNumNvalues(const std::vector< NnetIo > &io_vec, bool exhaustive)
This utility function can be used to obtain the number of distinct 'n' values in a training example...
void AcceptInput(const std::string &node_name, CuMatrix< BaseFloat > *input)
e.g.
void ResetGenerators(Nnet *nnet)
This function calls 'ResetGenerator()' on all components in 'nnet' that inherit from class RandomComp...
const CuMatrixBase< BaseFloat > & GetOutput(const std::string &node_name)
bool IsOutputNode(int32 node) const
Returns true if this is an output node, meaning that it is of type kDescriptor and is not directly fo...
void ApplyL2Regularization(const Nnet &nnet, BaseFloat l2_regularize_scale, Nnet *delta_nnet)
This function is used as part of the regular training workflow, prior to UpdateNnetWithMaxChange().
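As a sketch of that ordering, mirroring the calls in TrainInternalBackstitch above (scale_adding, max_change_scale, eg and the trainer members are assumed in scope):

// Fold the L2 gradient into delta_nnet_ first, then apply the
// max-change-limited parameter update.
ApplyL2Regularization(*nnet_,
    scale_adding * GetNumNvalues(eg.inputs, false) *
        opts_.nnet_config.l2_regularize_factor,
    delta_nnet_);
UpdateNnetWithMaxChange(*delta_nnet_, opts_.nnet_config.max_param_change,
                        max_change_scale, scale_adding, nnet_,
                        &max_change_stats_);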
CachingOptimizingCompiler compiler_
bool store_component_stats
std::vector< NnetChainSupervision > outputs
'outputs' contains the chain output supervision.
MaxChangeStats max_change_stats_
void AcceptInputs(const Nnet &nnet, const std::vector< NnetIo > &io)
This convenience function calls AcceptInput() in turn on all the inputs in the training example...
int32 num_minibatches_processed_
NnetChainExample is like NnetExample, but specialized for lattice-free (chain) training.
NnetTrainerOptions nnet_config
void ZeroComponentStats(Nnet *nnet)
Zeroes the component stats in all nonlinear components in the nnet.
Real TraceMatMat(const MatrixBase< Real > &A, const MatrixBase< Real > &B, MatrixTransposeType trans)
We need to declare this here as it will be a friend function.
bool PrintTotalStats(const std::string &output_name) const
void Train(const NnetChainExample &eg)
BaseFloat max_param_change
std::shared_ptr< const NnetComputation > Compile(const ComputationRequest &request)
Does the compilation and returns a const pointer to the result, which is owned by this class...
void ReadCache(std::istream &is, bool binary)
CuMatrixBase< BaseFloat >
Matrix for CUDA computing.
void WriteCache(std::ostream &os, bool binary)
void ConsolidateMemory(Nnet *nnet)
This just calls ConsolidateMemory() on all the components of the nnet.
class NnetComputer is responsible for executing the computation described in the "computation" object...
#define KALDI_ASSERT(cond)
NnetChainTrainer(const NnetChainTrainingOptions &config, const fst::StdVectorFst &den_fst, Nnet *nnet)
bool PrintTotalStats() const
chain::ChainTrainingOptions chain_config
void Print(const Nnet &nnet) const
int32 GetNodeIndex(const std::string &node_name) const
returns index associated with this node name, or -1 if no such index.
unordered_map< std::string, ObjectiveFunctionInfo, StringHasher > objf_info_
void GetChainComputationRequest(const Nnet &nnet, const NnetChainExample &eg, bool need_model_derivative, bool store_component_stats, bool use_xent_regularization, bool use_xent_derivative, ComputationRequest *request)
This function takes a NnetChainExample and produces a ComputationRequest.
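A minimal usage sketch of this request/compile pairing, mirroring Train() above (nnet, eg and a CachingOptimizingCompiler named compiler are assumed to be in scope):

ComputationRequest request;
GetChainComputationRequest(nnet, eg,
                           /*need_model_derivative=*/true,
                           /*store_component_stats=*/true,
                           /*use_xent_regularization=*/true,
                           /*use_xent_derivative=*/true,
                           &request);
std::shared_ptr<const NnetComputation> computation = compiler.Compile(request);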
NnetComputeOptions compute_config
BaseFloat backstitch_training_scale
BaseFloat batchnorm_stats_scale
void MulRowsVec(const CuVectorBase< Real > &scale)
scale i'th row by scale[i]
bool zero_component_stats
bool UpdateNnetWithMaxChange(const Nnet &delta_nnet, BaseFloat max_param_change, BaseFloat max_change_scale, BaseFloat scale, Nnet *nnet, std::vector< int32 > *num_max_change_per_component_applied, int32 *num_max_change_global_applied)
This function does the operation '*nnet += scale * delta_nnet', while respecting any max-parameter-ch...
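To illustrate just the global part of the max-change in isolation, a toy sketch (not Kaldi's implementation; Kaldi additionally enforces per-component limits): the proposed change scale * delta is clipped so its 2-norm never exceeds max_param_change before being added:

#include <cmath>
#include <cstddef>
#include <vector>

// Toy global max-change: scale the proposed update down if its 2-norm
// exceeds max_param_change, then add it to the parameters.
void AddWithMaxChange(const std::vector<float> &delta, float scale,
                      float max_param_change, std::vector<float> *params) {
  double norm = 0.0;
  for (float d : delta) norm += double(scale * d) * double(scale * d);
  norm = std::sqrt(norm);
  double factor = (norm > max_param_change) ? max_param_change / norm : 1.0;
  for (std::size_t i = 0; i < params->size(); ++i)
    (*params)[i] += float(factor) * scale * delta[i];
}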
int32 RandInt(int32 min_val, int32 max_val, struct RandomState *state)
void ProcessOutputs(bool is_backstitch_step2, const NnetChainExample &eg, NnetComputer *computer)
void Run()
This does either the forward or backward computation, depending when it is called (in a typical compu...