NnetDiscriminativeTrainer::NnetDiscriminativeTrainer(
    const NnetDiscriminativeOptions &opts, const TransitionModel &tmodel,
    const VectorBase<BaseFloat> &priors, Nnet *nnet):
    opts_(opts), tmodel_(tmodel), log_priors_(priors),
    nnet_(nnet),
    compiler_(*nnet, opts_.nnet_config.optimize_config),
    num_minibatches_processed_(0) {
  // ... (zeroing of component stats and setup of delta_nnet_ elided in
  //      this extract) ...
  if (opts.nnet_config.read_cache != "") {
    bool binary;
    try {
      Input ki(opts.nnet_config.read_cache, &binary);
      compiler_.ReadCache(ki.Stream(), binary);
      KALDI_LOG << "Read computation cache from "
                << opts.nnet_config.read_cache;
    } catch (...) {
      KALDI_WARN << "Could not open cached computation. "
                    "Probably this is the first training iteration.";
    }
  }
  log_priors_.ApplyLog();
}
void NnetDiscriminativeTrainer::Train(const NnetDiscriminativeExample &eg) {
  bool need_model_derivative = true;
  const NnetTrainerOptions &nnet_config = opts_.nnet_config;
  bool use_xent_regularization =
      (opts_.discriminative_config.xent_regularize != 0.0);
  ComputationRequest request;
  GetDiscriminativeComputationRequest(*nnet_, eg, need_model_derivative,
                                      nnet_config.store_component_stats,
                                      use_xent_regularization,
                                      need_model_derivative,
                                      &request);
  std::shared_ptr<const NnetComputation> computation =
      compiler_.Compile(request);
  NnetComputer computer(nnet_config.compute_config, *computation, *nnet_,
                        (delta_nnet_ == NULL ? nnet_ : delta_nnet_));
  computer.AcceptInputs(*nnet_, eg.inputs);
  computer.Run();                       // forward pass.
  this->ProcessOutputs(eg, &computer);
  computer.Run();                       // backward pass.

  if (delta_nnet_ != NULL) {
    BaseFloat scale = (1.0 - nnet_config.momentum);
    if (nnet_config.max_param_change != 0.0) {
      BaseFloat param_delta =
          std::sqrt(DotProduct(*delta_nnet_, *delta_nnet_)) * scale;
      if (param_delta > nnet_config.max_param_change) {
        if (param_delta - param_delta != 0.0) {
          KALDI_WARN << "Infinite parameter change, will not apply.";
          SetZero(false, delta_nnet_);
        } else {
          scale *= nnet_config.max_param_change / param_delta;
          KALDI_LOG << "Parameter change too big: " << param_delta << " > "
                    << "--max-param-change=" << nnet_config.max_param_change
                    << ", scaling by "
                    << nnet_config.max_param_change / param_delta;
        }
      }
    }
    AddNnet(*delta_nnet_, scale, nnet_);
    ScaleNnet(nnet_config.momentum, delta_nnet_);
  }
}
" > " 113 std::vector<NnetDiscriminativeSupervision>::const_iterator iter = eg.
outputs.begin(),
115 for (; iter != end; ++iter) {
118 if (node_index < 0 ||
125 nnet_output.NumCols(),
129 std::string xent_name = sup.
name +
"-xent";
132 xent_deriv.
Resize(nnet_output.NumRows(), nnet_output.NumCols(),
147 (use_xent ? &xent_deriv : NULL));
155 if (xent_objf != xent_objf) {
157 xent_objf = default_objf;
164 objf_info_[xent_name].UpdateStats(xent_name,
"xent",
171 nnet_output_deriv.MulRowsVec(cu_deriv_weights);
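// A minimal sketch of the row-weighting used above (dimensions are
// illustrative): MulRowsVec() scales the i'th row of a matrix by the i'th
// element of a vector, so frames whose derivative weight is 0.0 drop out
// of the gradient entirely.
//
//   CuVector<BaseFloat> w(3);       // e.g. w = [1.0, 0.0, 0.5]
//   CuMatrix<BaseFloat> d(3, 100);  // one row of derivatives per frame
//   d.MulRowsVec(w);                // row i of d is scaled by w(i)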
bool NnetDiscriminativeTrainer::PrintTotalStats() const {
  unordered_map<std::string, DiscriminativeObjectiveFunctionInfo,
                StringHasher>::const_iterator
      iter = objf_info_.begin(), end = objf_info_.end();
  bool ans = false;
  for (; iter != end; ++iter) {
    const std::string &name = iter->first;
    const DiscriminativeObjectiveFunctionInfo &info = iter->second;
    bool ok = info.PrintTotalStats(name,
                                   opts_.discriminative_config.criterion);
    ans = ans || ok;
  }
  return ans;
}
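// Note (not in the original source): the boolean result is true as long as
// at least one output accumulated nonzero weighted frames, so a caller can
// treat 'false' as a sign that no data was actually processed.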
void DiscriminativeObjectiveFunctionInfo::UpdateStats(
    const std::string &output_name,
    const std::string &criterion,
    int32 minibatches_per_phase,
    int32 minibatch_counter,
    discriminative::DiscriminativeObjectiveInfo this_minibatch_stats) {
  int32 phase = minibatch_counter / minibatches_per_phase;
  if (phase != current_phase) {
    KALDI_ASSERT(phase == current_phase + 1);  // else minibatches were skipped.
    PrintStatsForThisPhase(output_name, criterion, minibatches_per_phase);
    current_phase = phase;
    stats_this_phase.Reset();
  }
  stats_this_phase.Add(this_minibatch_stats);
  stats.Add(this_minibatch_stats);
}
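// Example of the phase arithmetic above: with minibatches_per_phase = 100,
// minibatch counters 0..99 map to phase 0 and counter 100 maps to phase 1;
// the transition triggers PrintStatsForThisPhase() for phase 0 and resets
// stats_this_phase, while the cumulative 'stats' keeps accumulating.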
void DiscriminativeObjectiveFunctionInfo::PrintStatsForThisPhase(
    const std::string &output_name,
    const std::string &criterion,
    int32 minibatches_per_phase) const {
  int32 start_minibatch = current_phase * minibatches_per_phase,
      end_minibatch = start_minibatch + minibatches_per_phase - 1;

  BaseFloat objf = (stats_this_phase.TotalObjf(criterion) /
                    stats_this_phase.tot_t_weighted);
  KALDI_LOG << "Average objective function for '" << output_name
            << "' for minibatches " << start_minibatch
            << '-' << end_minibatch << " is " << objf
            << " over " << stats_this_phase.tot_t_weighted << " frames.";
}
bool DiscriminativeObjectiveFunctionInfo::PrintTotalStats(
    const std::string &name, const std::string &criterion) const {
  BaseFloat objf = stats.TotalObjf(criterion) / stats.tot_t_weighted;

  double avg_gradients = (stats.tot_num_count + stats.tot_den_count) /
                         stats.tot_t_weighted;
  KALDI_LOG << "Average num+den count of stats is " << avg_gradients
            << " per frame, over "
            << stats.tot_t_weighted << " frames.";
  if (stats.tot_l2_term != 0.0) {
    KALDI_LOG << "Average l2 norm of output per frame is "
              << (stats.tot_l2_term / stats.tot_t_weighted)
              << " over " << stats.tot_t_weighted << " frames.";
  }
  KALDI_LOG << "Overall average objective function for '" << name << "' is "
            << objf << " over " << stats.tot_t_weighted << " frames.";
  KALDI_LOG << "[this line is to be parsed by a script:] "
            << criterion << "-per-frame="
            << objf;
  return (stats.tot_t_weighted != 0.0);
}