nnet-chain-diagnostics.h
Go to the documentation of this file.
1 // nnet3/nnet-chain-diagnostics.h
2 
3 // Copyright 2015 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #ifndef KALDI_NNET3_NNET_CHAIN_DIAGNOSTICS_H_
21 #define KALDI_NNET3_NNET_CHAIN_DIAGNOSTICS_H_
22 
23 #include "nnet3/nnet-example.h"
24 #include "nnet3/nnet-computation.h"
25 #include "nnet3/nnet-compute.h"
26 #include "nnet3/nnet-optimize.h"
28 #include "nnet3/nnet-diagnostics.h"
29 #include "chain/chain-training.h"
30 #include "chain/chain-den-graph.h"
31 
32 namespace kaldi {
33 namespace nnet3 {
34 
35 
37  double tot_weight;
38  double tot_like;
39  double tot_l2_term;
40  ChainObjectiveInfo(): tot_weight(0.0),
41  tot_like(0.0),
42  tot_l2_term(0.0) { }
43 };
44 
45 
55  public:
56  // Does not store references to the config arguments ('chain_config' is copied into a member), but does store a reference to 'nnet'.
58  const chain::ChainTrainingOptions &chain_config,
59  const fst::StdVectorFst &den_fst,
60  const Nnet &nnet);
61 
62  // This version of the constructor may only be called if
63  // nnet_config.store_component_stats == true and nnet_config.compute_deriv ==
64  // false; in that case the component stats are accumulated into 'nnet'. The
65  // stats are not zeroed here: if you want them to start from zero, call
66  // ZeroComponentStats(nnet) before using this object.
68  const chain::ChainTrainingOptions &chain_config,
69  const fst::StdVectorFst &den_fst,
70  Nnet *nnet);
71 
72 
73  // Reset the likelihood stats, and the derivative stats (if computed).
74  void Reset();
75 
76  // Computes the objective on one minibatch.
77  void Compute(const NnetChainExample &chain_eg);
78 
79  // Prints out the final stats, and returns true if there was a nonzero count.
80  bool PrintTotalStats() const;
81 
82  // returns the objective-function info for this output name (e.g. "output"),
83  // or NULL if there is no such info.
84  const ChainObjectiveInfo *GetObjective(const std::string &output_name) const;
85 
86  // This function returns the total objective over all output nodes recorded here, and
87  // outputs to 'tot_weight' the total weight (typically the number of frames)
88  // corresponding to it.
89  double GetTotalObjective(double *tot_weight) const;
90 
91  // If nnet_config.compute_deriv == true (see the constructor), returns a
92  // reference to the computed derivative. Otherwise crashes.
93  const Nnet &GetDeriv() const;
94 
96  private:
97  void ProcessOutputs(const NnetChainExample &chain_eg,
98  NnetComputer *computer);
99 
101  chain::ChainTrainingOptions chain_config_;
102  chain::DenominatorGraph den_graph_;
103  const Nnet &nnet_;
107  int32 num_minibatches_processed_; // this is only for diagnostics
108 
109  unordered_map<std::string, ChainObjectiveInfo, StringHasher> objf_info_;
110 
111 };
112 
// Zeros the component-level stats stored in 'nnet' using ZeroComponentStats(),
// then recomputes them using the supplied examples 'egs'.
//   egs           The chain examples the stats are recomputed from.
//   chain_config  Chain training options; NOTE(review): presumably used the
//                 same way as in the NnetChainComputeProb constructor, which
//                 takes the same chain_config/den_fst pair -- confirm in the .cc.
//   den_fst       Denominator FST (see the NnetChainComputeProb constructor).
//   nnet          The network whose stored component stats are zeroed and then
//                 recomputed (modified in place).
117 void RecomputeStats(const std::vector<NnetChainExample> &egs,
118  const chain::ChainTrainingOptions &chain_config,
119  const fst::StdVectorFst &den_fst,
120  Nnet *nnet);
121 
122 
123 
124 } // namespace nnet3
125 } // namespace kaldi
126 
127 #endif // KALDI_NNET3_NNET_CHAIN_DIAGNOSTICS_H_
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
chain::ChainTrainingOptions chain_config_
This class is for computing objective-function values in a nnet3+chain setup, for diagnostics...
This class enables you to do the compilation and optimization in one call, and also ensures that if t...
kaldi::int32 int32
unordered_map< std::string, ChainObjectiveInfo, StringHasher > objf_info_
fst::StdVectorFst StdVectorFst
The two main classes defined in this header are struct ComputationRequest, which basically defines a ...
void RecomputeStats(const std::vector< NnetChainExample > &egs, const chain::ChainTrainingOptions &chain_config_in, const fst::StdVectorFst &den_fst, Nnet *nnet)
This function zeros the stored component-level stats in the nnet using ZeroComponentStats(), then recomputes them with the supplied egs.
NnetChainExample is like NnetExample, but specialized for lattice-free (chain) training.
class NnetComputer is responsible for executing the computation described in the "computation" object...
Definition: nnet-compute.h:59