nnet-diagnostics.h
Go to the documentation of this file.
1 // nnet3/nnet-diagnostics.h
2 
3 // Copyright 2015 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #ifndef KALDI_NNET3_NNET_DIAGNOSTICS_H_
21 #define KALDI_NNET3_NNET_DIAGNOSTICS_H_
22 
23 #include "nnet3/nnet-example.h"
24 #include "nnet3/nnet-computation.h"
25 #include "nnet3/nnet-compute.h"
26 #include "nnet3/nnet-optimize.h"
28 #include "nnet3/nnet-training.h"
29 
30 namespace kaldi {
31 namespace nnet3 {
32 
33 
35  double tot_weight;
36  double tot_objective;
37  SimpleObjectiveInfo(): tot_weight(0.0),
38  tot_objective(0.0) { }
39 };
40 
41 /* This is used to store more detailed information about the objective,
42  * which will be used to compute accuracy per dimension.
43  * This might be sensible only for classification tasks.
44  */
46  // Counts for each of the classes in the output.
47  // In the simplest cases, this might be the number of frames for each class.
49 
50  // Objective contribution per-class
52 };
53 
54 
59  // note: the component stats, if stored, will be stored in the derivative nnet
60  // (cf. GetDeriv()) if compute_deriv is true; otherwise, you should use the
61  // constructor of NnetComputeProb that takes a pointer to the nnet, and the
62  // stats will be stored there.
64 
66 
71  debug_computation(false),
72  compute_deriv(false),
73  compute_accuracy(true),
74  store_component_stats(false),
75  compute_per_dim_accuracy(false) { }
76  void Register(OptionsItf *opts) {
77  // Registers this config's command-line options with 'opts'.
78  // compute_deriv and store_component_stats are deliberately not registered:
79  // neither is relevant for nnet3-compute-prob, so they keep their
80  // constructor default (false) unless the calling code sets them directly.
81  opts->Register("debug-computation", &debug_computation, "If true, turn on "
82  "debug for the actual computation (very verbose!)");
83  opts->Register("compute-accuracy", &compute_accuracy, "If true, compute "
84  "accuracy values as well as objective functions");
85  opts->Register("compute-per-dim-accuracy", &compute_per_dim_accuracy,
86  "If true, compute accuracy values per-dim");
87 
88  // Register the nested optimization options under the prefix "optimization".
89  ParseOptions optimization_opts("optimization", opts);
90  optimize_config.Register(&optimization_opts);
91  // Register the nested compiler options under the prefix "compiler".
92  ParseOptions compiler_opts("compiler", opts);
93  compiler_config.Register(&compiler_opts);
94  // Register the nested compute options under the prefix "computation".
95  ParseOptions compute_opts("computation", opts);
96  compute_config.Register(&compute_opts);
97  }
98 };
99 
100 
108  public:
109  // does not store a reference to 'config' but does store one to 'nnet'.
111  const Nnet &nnet);
112 
113  // This version of the constructor may only be called if
114  // config.store_component_stats == true and config.compute_deriv == false;
115  // it means it will store the component stats in 'nnet'. In this
116  // case you should call ZeroComponentStats(nnet) first if you want
117  // the stats to be zeroed first.
119  Nnet *nnet);
120 
121 
122  // Reset the likelihood stats, and the derivative stats (if computed).
123  void Reset();
124 
125  // compute objective on one minibatch.
126  void Compute(const NnetExample &eg);
127 
128  // Prints out the final stats, and returns true if there was a nonzero count.
129  bool PrintTotalStats() const;
130 
131  // returns the objective-function info for this output name (e.g. "output"),
132  // or NULL if there is no such info.
133  const SimpleObjectiveInfo *GetObjective(const std::string &output_name) const;
134 
135  // This function returns the total objective over all output nodes recorded here, and
136  // outputs to 'tot_weight' the total weight (typically the number of frames)
137  // corresponding to it.
138  double GetTotalObjective(double *tot_weight) const;
139 
140  // if config.compute_deriv == true, returns a reference to the
141  // computed derivative. Otherwise crashes.
142  const Nnet &GetDeriv() const;
143 
144  ~NnetComputeProb();
145  private:
146  void ProcessOutputs(const NnetExample &eg,
147  NnetComputer *computer);
148 
150  const Nnet &nnet_;
151 
155 
156  // this is only for diagnostics.
158 
159  unordered_map<std::string, SimpleObjectiveInfo, StringHasher> objf_info_;
160 
161  unordered_map<std::string, PerDimObjectiveInfo, StringHasher> accuracy_info_;
162 };
163 
164 
204 void ComputeAccuracy(const GeneralMatrix &supervision,
205  const CuMatrixBase<BaseFloat> &nnet_output,
207  BaseFloat *tot_accuracy,
208  VectorBase<BaseFloat> *tot_weight_vec = NULL,
209  VectorBase<BaseFloat> *tot_accuracy_vec = NULL);
210 
211 
212 } // namespace nnet3
213 } // namespace kaldi
214 
215 #endif // KALDI_NNET3_NNET_DIAGNOSTICS_H_
NnetExample is the input data and corresponding label (or labels) for one or more frames of input...
Definition: nnet-example.h:111
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void Register(OptionsItf *opts)
Definition: nnet-optimize.h:84
unordered_map< std::string, PerDimObjectiveInfo, StringHasher > accuracy_info_
This class is a wrapper that enables you to store a matrix in one of three forms: either as a Matrix<...
This class enables you to do the compilation and optimization in one call, and also ensures that if t...
NnetComputeProbOptions config_
kaldi::int32 int32
This class is for computing cross-entropy and accuracy values in a neural network, for diagnostics.
virtual void Register(const std::string &name, bool *ptr, const std::string &doc)=0
unordered_map< std::string, SimpleObjectiveInfo, StringHasher > objf_info_
The two main classes defined in this header are struct ComputationRequest, which basically defines a ...
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void ComputeAccuracy(const GeneralMatrix &supervision, const CuMatrixBase< BaseFloat > &nnet_output, BaseFloat *tot_weight_out, BaseFloat *tot_accuracy_out, VectorBase< BaseFloat > *tot_weight_vec, VectorBase< BaseFloat > *tot_accuracy_vec)
This function computes the frame accuracy for this minibatch.
Vector< BaseFloat > tot_objective_vec
void Register(OptionsItf *opts)
Definition: nnet-compute.h:42
Matrix for CUDA computing.
Definition: matrix-common.h:69
CachingOptimizingCompiler compiler_
A class representing a vector.
Definition: kaldi-vector.h:406
class NnetComputer is responsible for executing the computation described in the "computation" object...
Definition: nnet-compute.h:59
CachingOptimizingCompilerOptions compiler_config
Provides a vector abstraction class.
Definition: kaldi-vector.h:41