nnet-stats.h
Go to the documentation of this file.
1 // nnet2/nnet-stats.h
2 
3 // Copyright 2012 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #ifndef KALDI_NNET2_NNET_STATS_H_
21 #define KALDI_NNET2_NNET_STATS_H_
22 
23 #include "nnet2/nnet-nnet.h"
24 
25 namespace kaldi {
26 namespace nnet2 {
27 
/* The code declared in this file computes various statistics from a neural
   net. These statistics summarize quantities already present in the network
   as stored on disk, in particular the average values and average
   derivatives of the sigmoid nonlinearities.
*/
33 
34 struct NnetStatsConfig {
36  NnetStatsConfig(): bucket_width(0.025) { }
37 
38  void Register(OptionsItf *opts) {
39  opts->Register("bucket-width", &bucket_width, "Width of bucket in average-derivative "
40  "stats for analysis.");
41  }
42 };
43 
44 class NnetStats {
45  public:
46  NnetStats(int32 affine_component_index, BaseFloat bucket_width):
47  affine_component_index_(affine_component_index),
48  bucket_width_(bucket_width), global_(0, -1) { }
49 
50  // Use default copy constructor and assignment operator.
51 
52  void AddStats(BaseFloat avg_deriv, BaseFloat avg_value);
53 
54  void AddStatsFromNnet(const Nnet &nnet);
55 
56  void PrintStats(std::ostream &os);
57  private:
58 
59  struct StatsElement {
60  BaseFloat deriv_begin; // avg-deriv, beginning of bucket.
61  BaseFloat deriv_end; // avg-deriv, end of bucket.
62  BaseFloat deriv_sum; // sum of avg-deriv within bucket.
63  BaseFloat deriv_sumsq; // Sum-squared of avg-deriv within bucket.
64  BaseFloat abs_value_sum; // Sum of abs(avg-value). Tells us whether it's
65  // saturating at one or both ends.
66  BaseFloat abs_value_sumsq; // Sum-squared of abs(avg-value).
67  int32 count; // Number of nonlinearities in this bucket.
68 
69  StatsElement(BaseFloat deriv_begin,
70  BaseFloat deriv_end):
71  deriv_begin(deriv_begin), deriv_end(deriv_end), deriv_sum(0.0),
72  deriv_sumsq(0.0), abs_value_sum(0.0), abs_value_sumsq(0.0), count(0) { }
73  void AddStats(BaseFloat avg_deriv, BaseFloat avg_value);
74  // Outputs stats for this bucket; no newline
75  void PrintStats(std::ostream &os);
76  };
77  int32 BucketFor(BaseFloat avg_deriv); // returns the bucket
78  // for this avg-derivative value, and makes sure it is allocated.
79 
80  int32 affine_component_index_; // Component index of the affine component
81  // associated with this nonlinearity.
82  BaseFloat bucket_width_; // width of buckets of stats we store (in derivative values).
83 
84  std::vector<StatsElement> buckets_; // Stats divided into buckets by avg_deriv.
85  StatsElement global_; // All the stats.
86 
87 };
88 
89 void GetNnetStats(const NnetStatsConfig &config,
90  const Nnet &nnet,
91  std::vector<NnetStats> *stats);
92 
93 
94 } // namespace nnet2
95 } // namespace kaldi
96 
97 #endif // KALDI_NNET2_NNET_STATS_H_
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void Register(OptionsItf *opts)
Definition: nnet-stats.h:38
std::vector< StatsElement > buckets_
Definition: nnet-stats.h:84
kaldi::int32 int32
StatsElement global_
Definition: nnet-stats.h:85
void GetNnetStats(const NnetStatsConfig &config, const Nnet &nnet, std::vector< NnetStats > *stats)
Definition: nnet-stats.cc:99
virtual void Register(const std::string &name, bool *ptr, const std::string &doc)=0
StatsElement(BaseFloat deriv_begin, BaseFloat deriv_end)
Definition: nnet-stats.h:69
NnetStats(int32 affine_component_index, BaseFloat bucket_width)
Definition: nnet-stats.h:46