28 const chain::ChainTrainingOptions &chain_config,
31 nnet_config_(nnet_config),
32 chain_config_(chain_config),
33 den_graph_(den_fst, nnet.OutputDim(
"output")),
35 compiler_(nnet, nnet_config_.optimize_config, nnet_config_.compiler_config),
36 deriv_nnet_owned_(true),
38 num_minibatches_processed_(0) {
44 KALDI_ERR <<
"If you set store_component_stats == true and " 45 <<
"compute_deriv == false, use the other constructor.";
52 const chain::ChainTrainingOptions &chain_config,
57 den_graph_(den_fst, nnet->OutputDim(
"output")),
70 KALDI_ERR <<
"GetDeriv() called when no derivatives were requested.";
99 bool use_xent_regularization = (
chain_config_.xent_regularize != 0.0),
100 use_xent_derivative =
false;
102 store_component_stats, use_xent_regularization,
103 use_xent_derivative, &request);
104 std::shared_ptr<const NnetComputation> computation =
compiler_.
Compile(request);
119 std::vector<NnetChainSupervision>::const_iterator iter = eg.
outputs.begin(),
121 for (; iter != end; ++iter) {
124 if (node_index < 0 ||
130 std::string xent_name = sup.
name +
"-xent";
133 nnet_output_deriv.
Resize(nnet_output.NumRows(), nnet_output.NumCols(),
136 xent_deriv.
Resize(nnet_output.NumRows(), nnet_output.NumCols(),
139 BaseFloat tot_like, tot_l2_term, tot_weight;
143 &tot_like, &tot_l2_term, &tot_weight,
145 NULL), (use_xent ? &xent_deriv : NULL));
181 unordered_map<std::string, ChainObjectiveInfo, StringHasher>::const_iterator
185 for (; iter != end; ++iter) {
186 const std::string &name = iter->first;
192 tot_objf = like + l2_term;
194 KALDI_LOG <<
"Overall log-probability for '" 196 << like <<
" per frame" 197 <<
", over " << info.
tot_weight <<
" frames.";
199 KALDI_LOG <<
"Overall log-probability for '" 201 << like <<
" + " << l2_term <<
" = " << tot_objf <<
" per frame" 202 <<
", over " << info.
tot_weight <<
" frames.";
212 const std::string &output_name)
const {
213 unordered_map<std::string, ChainObjectiveInfo, StringHasher>::const_iterator
216 return &(iter->second);
222 double tot_objectives = 0.0;
223 double tot_weight = 0.0;
224 unordered_map<std::string, ChainObjectiveInfo, StringHasher>::const_iterator
226 for (; iter != end; ++iter) {
227 tot_objectives += iter->second.tot_like + iter->second.tot_l2_term;
228 tot_weight += iter->second.tot_weight;
231 if (total_weight) *total_weight = tot_weight;
232 return tot_objectives;
236 const std::vector<std::string> node_names = nnet.
GetNodeNames();
237 for (std::vector<std::string>::const_iterator it = node_names.begin();
238 it != node_names.end(); ++it) {
241 it->find(
"-xent") != std::string::npos) {
249 const chain::ChainTrainingOptions &chain_config_in,
252 KALDI_LOG <<
"Recomputing stats on nnet (affects batch-norm)";
253 chain::ChainTrainingOptions chain_config(chain_config_in);
255 chain_config.xent_regularize == 0) {
260 chain_config.xent_regularize = 0.1;
267 for (
size_t i = 0;
i < egs.size();
i++)
double GetTotalObjective(double *tot_weight) const
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
chain::ChainTrainingOptions chain_config_
void ScaleNnet(BaseFloat scale, Nnet *nnet)
Scales the nnet parameters and stats by this scale.
This class is for computing objective-function values in a nnet3+chain setup, for diagnostics...
chain::Supervision supervision
The supervision object, containing the FST.
std::vector< NnetIo > inputs
'inputs' contains the input to the network -- normally it has just one element, called "input"...
CachingOptimizingCompiler compiler_
unordered_map< std::string, ChainObjectiveInfo, StringHasher > objf_info_
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
fst::StdVectorFst StdVectorFst
bool store_component_stats
static bool HasXentOutputs(const Nnet &nnet)
bool PrintTotalStats() const
std::string name
the name of the output in the neural net; in simple setups it will just be "output".
This file contains some miscellaneous functions dealing with class Nnet.
void SetNnetAsGradient(Nnet *nnet)
Sets nnet as gradient by setting is_gradient_ to true and learning_rate_ to 1 for each UpdatableCompo...
void AcceptInput(const std::string &node_name, CuMatrix< BaseFloat > *input)
e.g.
NnetChainComputeProb(const NnetComputeProbOptions &nnet_config, const chain::ChainTrainingOptions &chain_config, const fst::StdVectorFst &den_fst, const Nnet &nnet)
const CuMatrixBase< BaseFloat > & GetOutput(const std::string &node_name)
chain::DenominatorGraph den_graph_
bool IsOutputNode(int32 node) const
Returns true if this is an output node, meaning that it is of type kDescriptor and is not directly fo...
std::vector< NnetChainSupervision > outputs
'outputs' contains the chain output supervision.
void RecomputeStats(const std::vector< NnetChainExample > &egs, const chain::ChainTrainingOptions &chain_config_in, const fst::StdVectorFst &den_fst, Nnet *nnet)
This function zeros the stored component-level stats in the nnet using ZeroComponentStats(), then recomputes them with the supplied egs.
void AcceptInputs(const Nnet &nnet, const std::vector< NnetIo > &io)
This convenience function calls AcceptInput() in turn on all the inputs in the training example...
NnetChainExample is like NnetExample, but specialized for lattice-free (chain) training.
void ZeroComponentStats(Nnet *nnet)
Zeroes the component stats in all nonlinear components in the nnet.
Real TraceMatMat(const MatrixBase< Real > &A, const MatrixBase< Real > &B, MatrixTransposeType trans)
We need to declare this here as it will be a friend function.
std::shared_ptr< const NnetComputation > Compile(const ComputationRequest &request)
Does the compilation and returns a const pointer to the result, which is owned by this class...
NnetComputeProbOptions nnet_config_
NnetComputeOptions compute_config
Matrix for CUDA computing.
class NnetComputer is responsible for executing the computation described in the "computation" object...
const Nnet & GetDeriv() const
#define KALDI_ASSERT(cond)
void Compute(const NnetChainExample &chain_eg)
int32 GetNodeIndex(const std::string &node_name) const
returns index associated with this node name, or -1 if no such index.
void GetChainComputationRequest(const Nnet &nnet, const NnetChainExample &eg, bool need_model_derivative, bool store_component_stats, bool use_xent_regularization, bool use_xent_derivative, ComputationRequest *request)
This function takes a NnetChainExample and produces a ComputationRequest.
void Resize(MatrixIndexT rows, MatrixIndexT cols, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Allocate the memory.
const ChainObjectiveInfo * GetObjective(const std::string &output_name) const
const std::vector< std::string > & GetNodeNames() const
returns vector of node names (needed by some parsing code, for instance).
void ProcessOutputs(const NnetChainExample &chain_eg, NnetComputer *computer)
int32 num_minibatches_processed_
void Run()
This does either the forward or backward computation, depending when it is called (in a typical compu...