30 deriv_nnet_owned_(true),
32 compiler_(nnet, config_.optimize_config, config_.compiler_config),
33 num_minibatches_processed_(0) {
39 KALDI_ERR <<
"If you set store_component_stats == true and " 40 <<
"compute_deriv == false, use the other constructor.";
60 KALDI_ERR <<
"GetDeriv() called when no derivatives were requested.";
84 store_component_stats,
86 std::shared_ptr<const NnetComputation> computation =
compiler_.
Compile(request);
99 std::vector<NnetIo>::const_iterator iter = eg.
io.begin(),
101 for (; iter != end; ++iter) {
110 KALDI_ERR <<
"Nnet versus example output dimension (num-classes) " 111 <<
"mismatch for '" << io.
name <<
"': " << output.
NumCols()
118 supply_deriv, computer,
119 &tot_weight, &tot_objf);
136 &tot_weight, &tot_accuracy,
156 for (; iter != end; ++iter) {
157 const std::string &name = iter->first;
161 const SimpleObjectiveInfo &info = iter->second;
163 << (obj_type ==
kLinear ?
"log-likelihood" :
"objective")
164 <<
" for '" << name <<
"' is " 166 <<
", over " << info.
tot_weight <<
" frames.";
177 for (; iter != end; ++iter) {
178 const std::string &name = iter->first;
179 const PerDimObjectiveInfo &info = iter->second;
180 KALDI_LOG <<
"Overall accuracy for '" << name <<
"' is " 182 <<
", over " << info.
tot_weight <<
" frames.";
191 accuracy_vec(
j) = -1.0;
195 KALDI_LOG <<
"Overall per-dim accuracy vector for '" << name
196 <<
"' is " << accuracy_vec <<
" per frame" 197 <<
", over " << info.
tot_weight <<
" frames.";
213 num_cols = nnet_output.
NumCols();
215 supervision.
NumCols() == num_cols);
217 if (tot_accuracy_vec || tot_weight_vec)
219 tot_accuracy_vec->
Dim() == num_cols &&
220 tot_weight_vec->
Dim() == num_cols);
221 if (tot_accuracy_vec) tot_accuracy_vec->
Set(0.0);
222 if (tot_weight_vec) tot_weight_vec->
Set(0.0);
226 std::vector<int32> best_index_cpu;
231 double tot_weight = 0.0,
236 switch (supervision.
Type()) {
240 for (
int32 r = 0; r < num_rows; r++) {
244 vec.
Max(&best_index);
245 tot_weight += row_sum;
247 (*tot_weight_vec)(best_index) += row_sum;
248 if (best_index == best_index_cpu[r]) {
249 tot_accuracy += row_sum;
250 if (tot_accuracy_vec)
251 (*tot_accuracy_vec)(best_index) += row_sum;
258 for (
int32 r = 0; r < num_rows; r++) {
262 vec.
Max(&best_index);
263 tot_weight += row_sum;
265 (*tot_weight_vec)(best_index) += row_sum;
266 if (best_index == best_index_cpu[r]) {
267 tot_accuracy += row_sum;
268 if (tot_accuracy_vec)
269 (*tot_accuracy_vec)(best_index) += row_sum;
276 for (
int32 r = 0; r < num_rows; r++) {
280 row.
Max(&best_index);
282 tot_weight += row_sum;
284 (*tot_weight_vec)(best_index) += row_sum;
285 if (best_index == best_index_cpu[r]) {
286 tot_accuracy += row_sum;
287 if (tot_accuracy_vec)
288 (*tot_accuracy_vec)(best_index) += row_sum;
293 default:
KALDI_ERR <<
"Bad general-matrix type.";
295 *tot_weight_out = tot_weight;
296 *tot_accuracy_out = tot_accuracy;
300 const std::string &output_name)
const {
301 unordered_map<std::string, SimpleObjectiveInfo, StringHasher>::const_iterator
304 return &(iter->second);
310 double tot_objectives = 0.0;
311 double tot_weight = 0.0;
312 unordered_map<std::string, SimpleObjectiveInfo, StringHasher>::const_iterator
314 for (; iter != end; ++iter) {
315 tot_objectives += iter->second.tot_objective;
316 tot_weight += iter->second.tot_weight;
319 if (total_weight) *total_weight = tot_weight;
320 return tot_objectives;
NnetExample is the input data and corresponding label (or labels) for one or more frames of input...
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
unordered_map< std::string, PerDimObjectiveInfo, StringHasher > accuracy_info_
void ScaleNnet(BaseFloat scale, Nnet *nnet)
Scales the nnet parameters and stats by this scale.
This class is a wrapper that enables you to store a matrix in one of three forms: either as a Matrix<...
const Nnet & GetDeriv() const
Real Max(int32 *index) const
void CopyToVec(std::vector< T > *dst) const
This function resizes *dst if needed.
void GetMatrix(Matrix< BaseFloat > *mat) const
Outputs the contents as a matrix.
void ComputeObjectiveFunction(const GeneralMatrix &supervision, ObjectiveType objective_type, const std::string &output_name, bool supply_deriv, NnetComputer *computer, BaseFloat *tot_weight, BaseFloat *tot_objf)
This function computes the objective function, and if supply_deriv = true, supplies its derivative to...
void Compute(const NnetExample &eg)
NnetComputeProbOptions config_
NnetComputeProb(const NnetComputeProbOptions &config, const Nnet &nnet)
Vector< BaseFloat > tot_weight_vec
GeneralMatrix features
The features or labels.
const Matrix< BaseFloat > & GetFullMatrix() const
Returns the contents as a Matrix<BaseFloat>.
void ProcessOutputs(const NnetExample &eg, NnetComputer *computer)
bool store_component_stats
ObjectiveType objective_type
A hashing function object for strings.
bool PrintTotalStats() const
unordered_map< std::string, SimpleObjectiveInfo, StringHasher > objf_info_
This file contains some miscellaneous functions dealing with class Nnet.
void SetNnetAsGradient(Nnet *nnet)
Sets nnet as gradient by setting is_gradient_ to true and learning_rate_ to 1 for each UpdatableCompo...
MatrixIndexT NumCols() const
const NetworkNode & GetNode(int32 node) const
Returns a const reference to a particular numbered network node.
void ComputeAccuracy(const GeneralMatrix &supervision, const CuMatrixBase< BaseFloat > &nnet_output, BaseFloat *tot_weight_out, BaseFloat *tot_accuracy_out, VectorBase< BaseFloat > *tot_weight_vec, VectorBase< BaseFloat > *tot_accuracy_vec)
This function computes the frame accuracy for this minibatch.
Vector< BaseFloat > tot_objective_vec
const CuMatrixBase< BaseFloat > & GetOutput(const std::string &node_name)
bool IsOutputNode(int32 node) const
Returns true if this is an output node, meaning that it is of type kDescriptor and is not directly fo...
GeneralMatrixType Type() const
Returns the type of the matrix: kSparseMatrix, kCompressedMatrix or kFullMatrix.
void AcceptInputs(const Nnet &nnet, const std::vector< NnetIo > &io)
This convenience function calls AcceptInput() in turn on all the inputs in the training example...
int32 num_minibatches_processed_
Real Max() const
Returns the maximum value of any element, or -infinity for the empty vector.
bool compute_per_dim_accuracy
MatrixIndexT Dim() const
Returns the dimension of the vector.
const SimpleObjectiveInfo * GetObjective(const std::string &output_name) const
Real Sum() const
Returns sum of the elements.
void FindRowMaxId(CuArray< int32 > *id) const
Find the id of the maximal element for each row (resizes the 'id' array to the appropriate size)...
std::shared_ptr< const NnetComputation > Compile(const ComputationRequest &request)
Does the compilation and returns a const pointer to the result, which is owned by this class...
ObjectiveType
This enum is for a kind of annotation we associate with output nodes of the network; it's for the con...
NnetComputeOptions compute_config
Matrix for CUDA computing.
MatrixIndexT NumCols() const
CachingOptimizingCompiler compiler_
A class representing a vector.
class NnetComputer is responsible for executing the computation described in the "computation" object...
#define KALDI_ASSERT(cond)
MatrixIndexT NumRows() const
void Set(Real f)
Set all members of a vector to a specified value.
double GetTotalObjective(double *tot_weight) const
std::string name
The name of the input in the neural net; in simple setups it will just be "input".
union kaldi::nnet3::NetworkNode::@15 u
int32 GetNodeIndex(const std::string &node_name) const
Returns the index associated with this node name, or -1 if there is no such index.
const SparseMatrix< BaseFloat > & GetSparseMatrix() const
Returns the contents as a SparseMatrix.
MatrixIndexT NumRows() const
Dimensions.
Provides a vector abstraction class.
std::vector< NnetIo > io
"io" contains the input and output.
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
void GetComputationRequest(const Nnet &nnet, const NnetExample &eg, bool need_model_derivative, bool store_component_stats, ComputationRequest *request)
This function takes a NnetExample (which should already have been frame-selected, if desired...
const SparseVector< Real > & Row(MatrixIndexT r) const
void Run()
This does either the forward or backward computation, depending on when it is called (in a typical compu...