All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
DiscriminativeObjectiveInfo Struct Reference

#include <discriminative-training.h>

Collaboration diagram for DiscriminativeObjectiveInfo:

Public Member Functions

void Print (const std::string &criterion, bool print_avg_gradients=false, bool print_avg_output=false) const
 
void PrintAll (const std::string &criterion) const
 
void PrintAvgGradientForPdf (int32 pdf_id) const
 
void Add (const DiscriminativeObjectiveInfo &other)
 
double TotalObjf (const std::string &criterion) const
 
bool AccumulateGradients () const
 
bool AccumulateOutput () const
 
 DiscriminativeObjectiveInfo ()
 
 DiscriminativeObjectiveInfo (int32 num_pdfs)
 
 DiscriminativeObjectiveInfo (const DiscriminativeOptions &opts)
 
void Reset ()
 
void Configure (const DiscriminativeOptions &opts)
 

Public Attributes

double tot_t
 
double tot_t_weighted
 
double tot_objf
 
double tot_num_count
 
double tot_den_count
 
double tot_num_objf
 
double tot_l2_term
 
bool accumulate_gradients
 
bool accumulate_output
 
int32 num_pdfs
 
CuVector< double > gradients
 
CuVector< double > output
 

Detailed Description

Definition at line 130 of file discriminative-training.h.

Constructor & Destructor Documentation

DiscriminativeObjectiveInfo ( )

Definition at line 28 of file discriminative-training.cc.

28  {
29  std::memset(this, 0, sizeof(*this));
30 }

Member Function Documentation

void Add ( const DiscriminativeObjectiveInfo &  other)

Definition at line 560 of file discriminative-training.cc.

References DiscriminativeObjectiveInfo::AccumulateGradients(), DiscriminativeObjectiveInfo::AccumulateOutput(), CuVectorBase< Real >::AddVec(), DiscriminativeObjectiveInfo::gradients, DiscriminativeObjectiveInfo::output, DiscriminativeObjectiveInfo::tot_den_count, DiscriminativeObjectiveInfo::tot_l2_term, DiscriminativeObjectiveInfo::tot_num_count, DiscriminativeObjectiveInfo::tot_num_objf, DiscriminativeObjectiveInfo::tot_objf, DiscriminativeObjectiveInfo::tot_t, and DiscriminativeObjectiveInfo::tot_t_weighted.

Referenced by DiscriminativeComputation::Compute(), and DiscriminativeObjectiveFunctionInfo::UpdateStats().

560  {
561  tot_t += other.tot_t;
562  tot_t_weighted += other.tot_t_weighted;
563  tot_objf += other.tot_objf; // Actually tot_den_objf for mmi
564  tot_num_count += other.tot_num_count;
565  tot_den_count += other.tot_den_count;
566  tot_num_objf += other.tot_num_objf; // Only for mmi
567  tot_l2_term += other.tot_l2_term;
568 
569  if (AccumulateGradients()) {
570  gradients.AddVec(1.0, other.gradients);
571  }
572  if (AccumulateOutput()) {
573  output.AddVec(1.0, other.output);
574  }
575 }
void AddVec(Real alpha, const CuVectorBase< Real > &vec, Real beta=1.0)
Definition: cu-vector.cc:1126
void Print ( const std::string &  criterion,
bool  print_avg_gradients = false,
bool  print_avg_output = false 
) const

Definition at line 577 of file discriminative-training.cc.

References DiscriminativeObjectiveInfo::AccumulateGradients(), DiscriminativeObjectiveInfo::AccumulateOutput(), DiscriminativeObjectiveInfo::gradients, KALDI_LOG, KALDI_VLOG, DiscriminativeObjectiveInfo::output, VectorBase< Real >::Scale(), DiscriminativeObjectiveInfo::tot_den_count, DiscriminativeObjectiveInfo::tot_num_count, DiscriminativeObjectiveInfo::tot_num_objf, DiscriminativeObjectiveInfo::tot_objf, DiscriminativeObjectiveInfo::tot_t, and DiscriminativeObjectiveInfo::tot_t_weighted.

Referenced by DiscriminativeComputation::Compute(), and DiscriminativeObjectiveInfo::PrintAll().

579  {
580  if (criterion == "mmi") {
581  double num_objf = tot_num_objf / tot_t_weighted,
582  den_objf = tot_objf / tot_t_weighted;
583  double objf = num_objf - den_objf;
584 
585  double avg_post_per_frame = tot_num_count / tot_t_weighted;
586 
587  KALDI_LOG << "Number of frames is " << tot_t
588  << " (weighted: " << tot_t_weighted
589  << "), average (num or den) posterior per frame is "
590  << avg_post_per_frame;
591  KALDI_LOG << "MMI objective function is " << num_objf << " - "
592  << den_objf << " = " << objf << " per frame, over "
593  << tot_t_weighted << " frames.";
594  } else if (criterion == "mpfe") {
595  double avg_gradients = (tot_num_count + tot_den_count) / tot_t_weighted;
596  double objf = tot_objf / tot_t_weighted;
597  KALDI_LOG << "Average num+den count of MPFE stats is " << avg_gradients
598  << " per frame, over "
599  << tot_t_weighted << " frames";
600  KALDI_LOG << "MPFE objective function is " << objf
601  << " per frame, over " << tot_t_weighted << " frames.";
602  } else if (criterion == "smbr") {
603  double avg_gradients = (tot_num_count + tot_den_count) / tot_t_weighted;
604  double objf = tot_objf / tot_t_weighted;
605  KALDI_LOG << "Average num+den count of SMBR stats is " << avg_gradients
606  << " per frame, over "
607  << tot_t_weighted << " frames";
608  KALDI_LOG << "SMBR objective function is " << objf
609  << " per frame, over " << tot_t_weighted << " frames.";
610  }
611 
612  if (AccumulateGradients()) {
613  Vector<double> temp(gradients);
614  temp.Scale(1.0/tot_t_weighted);
615  if (print_avg_gradients) {
616  KALDI_LOG << "Vector of average gradients wrt output activations is: \n" << temp;
617  } else {
618  KALDI_VLOG(4) << "Vector of average gradients wrt output activations is: \n" << temp;
619  }
620  }
621  if (AccumulateOutput()) {
622  Vector<double> temp(output);
623  temp.Scale(1.0/tot_t_weighted);
624  if (print_avg_output) {
625  KALDI_LOG << "Average DNN output is: \n" << temp;
626  } else {
627  KALDI_VLOG(4) << "Average DNN output is: \n" << temp;
628  }
629  }
630 }
#define KALDI_VLOG(v)
Definition: kaldi-error.h:136
#define KALDI_LOG
Definition: kaldi-error.h:133
void PrintAll ( const std::string &  criterion) const
inline

Definition at line 169 of file discriminative-training.h.

References DiscriminativeObjectiveInfo::Print().

Referenced by DiscriminativeComputation::Compute(), and NnetDiscriminativeComputeObjf::PrintTotalStats().

169  {
170  Print(criterion, true, true);
171  }
void Print(const std::string &criterion, bool print_avg_gradients=false, bool print_avg_output=false) const
void PrintAvgGradientForPdf ( int32  pdf_id) const

Definition at line 632 of file discriminative-training.cc.

References CuVectorBase< Real >::Dim(), DiscriminativeObjectiveInfo::gradients, KALDI_LOG, and DiscriminativeObjectiveInfo::tot_t_weighted.

632  {
633  if (pdf_id < gradients.Dim() && pdf_id >= 0) {
634  KALDI_LOG << "Average gradient wrt output activations of pdf " << pdf_id
635  << " is " << gradients(pdf_id) / tot_t_weighted
636  << " per frame, over "
637  << tot_t_weighted << " frames";
638  }
639 }
MatrixIndexT Dim() const
Dimensions.
Definition: cu-vector.h:67
#define KALDI_LOG
Definition: kaldi-error.h:133
void Reset ( )

Definition at line 53 of file discriminative-training.cc.

References DiscriminativeObjectiveInfo::gradients, DiscriminativeObjectiveInfo::output, CuVectorBase< Real >::SetZero(), DiscriminativeObjectiveInfo::tot_den_count, DiscriminativeObjectiveInfo::tot_l2_term, DiscriminativeObjectiveInfo::tot_num_count, DiscriminativeObjectiveInfo::tot_num_objf, DiscriminativeObjectiveInfo::tot_objf, DiscriminativeObjectiveInfo::tot_t, and DiscriminativeObjectiveInfo::tot_t_weighted.

Referenced by DiscriminativeComputation::Compute(), DiscriminativeObjectiveInfo::DiscriminativeObjectiveInfo(), and DiscriminativeObjectiveFunctionInfo::UpdateStats().

Member Data Documentation

bool accumulate_gradients

Definition at line 148 of file discriminative-training.h.

Referenced by DiscriminativeObjectiveInfo::Configure().

bool accumulate_output

Definition at line 151 of file discriminative-training.h.

Referenced by DiscriminativeObjectiveInfo::Configure().

int32 num_pdfs

Definition at line 155 of file discriminative-training.h.

Referenced by DiscriminativeObjectiveInfo::Configure().


The documentation for this struct was generated from the following files: