All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
NnetStats Class Reference

#include <nnet-stats.h>

Collaboration diagram for NnetStats:

Classes

struct  StatsElement
 

Public Member Functions

 NnetStats (int32 affine_component_index, BaseFloat bucket_width)
 
void AddStats (BaseFloat avg_deriv, BaseFloat avg_value)
 
void AddStatsFromNnet (const Nnet &nnet)
 
void PrintStats (std::ostream &os)
 

Private Member Functions

int32 BucketFor (BaseFloat avg_deriv)
 

Private Attributes

int32 affine_component_index_
 
BaseFloat bucket_width_
 
std::vector< StatsElement > buckets_
 
StatsElement global_
 

Detailed Description

Definition at line 44 of file nnet-stats.h.

Constructor & Destructor Documentation

NnetStats ( int32  affine_component_index,
BaseFloat  bucket_width 
)
inline

Definition at line 46 of file nnet-stats.h.

Referenced by kaldi::nnet2::GetNnetStats().

46  :
47  affine_component_index_(affine_component_index),
48  bucket_width_(bucket_width), global_(0, -1) { }
StatsElement global_
Definition: nnet-stats.h:85

Member Function Documentation

void AddStats ( BaseFloat  avg_deriv,
BaseFloat  avg_value 
)

Definition at line 58 of file nnet-stats.cc.

References NnetStats::StatsElement::AddStats(), NnetStats::BucketFor(), NnetStats::buckets_, and NnetStats::global_.

Referenced by NnetStats::AddStatsFromNnet().

58  {
59  global_.AddStats(avg_deriv, avg_value);
60  buckets_[BucketFor(avg_deriv)].AddStats(avg_deriv, avg_value);
61 }
std::vector< StatsElement > buckets_
Definition: nnet-stats.h:84
StatsElement global_
Definition: nnet-stats.h:85
int32 BucketFor(BaseFloat avg_deriv)
Definition: nnet-stats.cc:47
void AddStats(BaseFloat avg_deriv, BaseFloat avg_value)
Definition: nnet-stats.cc:39
void AddStatsFromNnet ( const Nnet &  nnet)

Definition at line 63 of file nnet-stats.cc.

References NnetStats::AddStats(), NnetStats::affine_component_index_, count, NonlinearComponent::Count(), NonlinearComponent::DerivSum(), CuVectorBase< Real >::Dim(), Nnet::GetComponent(), rnnlm::i, KALDI_ASSERT, KALDI_ERR, KALDI_WARN, and NonlinearComponent::ValueSum().

63  {
64  const AffineComponent *ac = dynamic_cast<const AffineComponent*>(
65  &(nnet.GetComponent(affine_component_index_)));
66  KALDI_ASSERT(ac != NULL); // would be an error in calling code.
67  const NonlinearComponent *nc = dynamic_cast<const NonlinearComponent*>(
68  &(nnet.GetComponent(affine_component_index_ + 1)));
69  KALDI_ASSERT(nc != NULL); // would be an error in calling code.
70 
71  double count = nc->Count();
72  if (count == 0) {
73  KALDI_WARN << "No stats stored with nonlinear component";
74  return;
75  }
76  const CuVector<double> &value_sum = nc->ValueSum();
77  const CuVector<double> &deriv_sum = nc->DerivSum();
78  if (value_sum.Dim() != deriv_sum.Dim())
79  KALDI_ERR << "Error computing nnet stats: probably you are "
80  << "trying to compute stats for a sigmoid layer.";
81  for (int32 i = 0; i < value_sum.Dim(); i++) {
82  BaseFloat avg_value = value_sum(i) / count,
83  avg_deriv = deriv_sum(i) / count;
84  AddStats(avg_deriv, avg_value);
85  }
86 }
void AddStats(BaseFloat avg_deriv, BaseFloat avg_value)
Definition: nnet-stats.cc:58
const size_t count
float BaseFloat
Definition: kaldi-types.h:29
#define KALDI_ERR
Definition: kaldi-error.h:127
#define KALDI_WARN
Definition: kaldi-error.h:130
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:169
int32 BucketFor ( BaseFloat  avg_deriv)
private

Definition at line 47 of file nnet-stats.cc.

References NnetStats::bucket_width_, NnetStats::buckets_, and KALDI_ASSERT.

Referenced by NnetStats::AddStats().

47  {
48  KALDI_ASSERT(avg_deriv >= 0.0);
50  // cast ratio to int. Since we do +0.5, this rounds down.
51  int32 index = static_cast<int32>(avg_deriv / bucket_width_ + 0.5);
52  while (index >= static_cast<int32>(buckets_.size()))
53  buckets_.push_back(StatsElement(buckets_.size() * bucket_width_,
54  (buckets_.size() + 1) * bucket_width_));
55  return index;
56 }
std::vector< StatsElement > buckets_
Definition: nnet-stats.h:84
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:169
void PrintStats ( std::ostream &  os)

Definition at line 88 of file nnet-stats.cc.

References NnetStats::buckets_, NnetStats::global_, rnnlm::i, and NnetStats::StatsElement::PrintStats().

88  {
89  os << "Stats for buckets:" << std::endl;
90  for (size_t i = 0; i < buckets_.size(); i++) {
91  buckets_[i].PrintStats(os);
92  os << std::endl;
93  }
94  os << "Global stats: ";
95  global_.PrintStats(os);
96  os << std::endl;
97 }
std::vector< StatsElement > buckets_
Definition: nnet-stats.h:84
StatsElement global_
Definition: nnet-stats.h:85
void PrintStats(std::ostream &os)
Definition: nnet-stats.cc:25

Member Data Documentation

int32 affine_component_index_
private

Definition at line 80 of file nnet-stats.h.

Referenced by NnetStats::AddStatsFromNnet().

BaseFloat bucket_width_
private

Definition at line 82 of file nnet-stats.h.

Referenced by NnetStats::BucketFor().

std::vector<StatsElement> buckets_
private

Definition at line 84 of file nnet-stats.h.

Referenced by NnetStats::AddStats(), NnetStats::BucketFor(), and NnetStats::PrintStats().

StatsElement global_
private

Definition at line 85 of file nnet-stats.h.

Referenced by NnetStats::AddStats(), and NnetStats::PrintStats().


The documentation for this class was generated from the following files: