nnet-discriminative-diagnostics.h
Go to the documentation of this file.
1 // nnet3/nnet-discriminative-diagnostics.h
2 
3 // Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey)
4 // Copyright 2014-2015 Vimal Manohar
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #ifndef KALDI_NNET3_NNET_DISCRIMINATIVE_DIAGNOSTICS_H_
22 #define KALDI_NNET3_NNET_DISCRIMINATIVE_DIAGNOSTICS_H_
23 
24 #include "nnet3/nnet-example.h"
25 #include "nnet3/nnet-computation.h"
26 #include "nnet3/nnet-compute.h"
27 #include "nnet3/nnet-optimize.h"
29 #include "nnet3/nnet-diagnostics.h"
31 
32 namespace kaldi {
33 namespace nnet3 {
34 
40  public:
41  // does not store a reference to 'config' but does store one to 'nnet'.
43  const discriminative::DiscriminativeOptions &discriminative_config,
44  const TransitionModel &tmodel,
45  const VectorBase<BaseFloat> &priors,
46  const Nnet &nnet);
47 
48  // Resets the likelihood stats, and the derivative stats (if computed).
49  void Reset();
50 
51  // Computes the objective function on one minibatch ('eg').
52  void Compute(const NnetDiscriminativeExample &eg);
53 
54  // Prints out the final stats, and returns true if there was a nonzero count.
55  bool PrintTotalStats() const;
56 
57  // returns the objective-function info for this output name (e.g. "output"),
58  // or NULL if there is no such info.
60  const std::string &output_name) const;
61 
62  // If config.compute_deriv == true (presumably the 'nnet_config' constructor
63  // argument -- confirm), returns a reference to the computed derivative. Otherwise crashes.
64  const Nnet &GetDeriv() const;
65 
67  private:
69  NnetComputer *computer);
70 
72 
76  const Nnet &nnet_;
79  int32 num_minibatches_processed_; // this is only for diagnostics
80 
81  unordered_map<std::string, discriminative::DiscriminativeObjectiveInfo, StringHasher> objf_info_;
82 };
83 
84 } // namespace nnet3
85 } // namespace kaldi
86 
87 #endif // KALDI_NNET3_NNET_DISCRIMINATIVE_DIAGNOSTICS_H_
88 
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
unordered_map< std::string, discriminative::DiscriminativeObjectiveInfo, StringHasher > objf_info_
This class enables you to do the compilation and optimization in one call, and also ensures that if t...
kaldi::int32 int32
The two main classes defined in this header are struct ComputationRequest, which basically defines a ...
const discriminative::DiscriminativeObjectiveInfo * GetObjective(const std::string &output_name) const
NnetDiscriminativeComputeObjf(const NnetComputeProbOptions &nnet_config, const discriminative::DiscriminativeOptions &discriminative_config, const TransitionModel &tmodel, const VectorBase< BaseFloat > &priors, const Nnet &nnet)
This class is for computing objective-function values in nnet3 discriminative training, for diagnostics.
discriminative::DiscriminativeOptions discriminative_config_
class NnetComputer is responsible for executing the computation described in the "computation" object...
Definition: nnet-compute.h:59
void Compute(const NnetDiscriminativeExample &eg)
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
NnetDiscriminativeExample is like NnetExample, but specialized for sequence training.
void ProcessOutputs(const NnetDiscriminativeExample &eg, NnetComputer *computer)