NnetDiscriminativeStats Struct Reference

#include <nnet-compute-discriminative.h>


Public Member Functions

 NnetDiscriminativeStats ()
 
void Print (std::string criterion)
 
void Add (const NnetDiscriminativeStats &other)
 

Public Attributes

double tot_t
 
double tot_t_weighted
 
double tot_num_count
 
double tot_num_objf
 
double tot_den_objf
 

Detailed Description

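Accumulates frame counts and objective-function totals during nnet2 discriminative training. Print() reports them for one of three criteria: "mmi", "smbr", or "mpfe".

Definition at line 70 of file nnet-compute-discriminative.h.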

Constructor & Destructor Documentation

◆ NnetDiscriminativeStats()

NnetDiscriminativeStats ( )

Definition at line 80 of file nnet-compute-discriminative.h.

References NnetDiscriminativeUpdateOptions::criterion, kaldi::nnet2::NnetDiscriminativeUpdate(), and fst::Print().

{ std::memset(this, 0, sizeof(*this)); }
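The constructor zeroes the whole object with std::memset, which is only sound because every member is a plain double (on IEEE 754 platforms, all-zero bits represent 0.0). A minimal sketch of that idiom, assuming a hypothetical replica struct: StatsReplica and StatsReplicaNoMemset are illustrative names, not Kaldi types. It adds a compile-time guard and shows a memset-free alternative.

```cpp
#include <cassert>
#include <cstring>
#include <type_traits>

// Hypothetical stand-in for NnetDiscriminativeStats: same five double
// fields and the same memset-based constructor. Not the Kaldi type itself.
struct StatsReplica {
  double tot_t;
  double tot_t_weighted;
  double tot_num_count;
  double tot_num_objf;
  double tot_den_objf;
  // Zero every byte of the object; valid only for a trivially copyable
  // type whose members are all arithmetic.
  StatsReplica() { std::memset(this, 0, sizeof(*this)); }
};

// The memset idiom silently breaks if a non-trivial member (for example a
// std::string) is ever added, so guard the assumption at compile time.
static_assert(std::is_trivially_copyable<StatsReplica>::value,
              "memset-zeroing requires a trivially copyable type");

// Equivalent result without memset: default member initializers.
struct StatsReplicaNoMemset {
  double tot_t = 0.0;
  double tot_t_weighted = 0.0;
  double tot_num_count = 0.0;
  double tot_num_objf = 0.0;
  double tot_den_objf = 0.0;
};
```

The static_assert documents the precondition the memset relies on; the default-member-initializer version reaches the same zero state without depending on the object representation.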

Member Function Documentation

◆ Add()

void Add ( const NnetDiscriminativeStats & other )
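No body for Add() is documented here. Because every field is a running total, a natural reading is that Add() sums other's fields into this object; the sketch below assumes exactly that. StatsReplica and MergeAll are hypothetical names, and the per-worker merge only illustrates how a parallel driver such as NnetDiscriminativeUpdateParallel() might combine partial results, not how it actually does.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in with the same fields as NnetDiscriminativeStats.
struct StatsReplica {
  double tot_t = 0.0;
  double tot_t_weighted = 0.0;
  double tot_num_count = 0.0;
  double tot_num_objf = 0.0;
  double tot_den_objf = 0.0;

  // Assumed semantics: each accumulator is a plain sum, so merging two
  // partial results is field-by-field addition.
  void Add(const StatsReplica &other) {
    tot_t += other.tot_t;
    tot_t_weighted += other.tot_t_weighted;
    tot_num_count += other.tot_num_count;
    tot_num_objf += other.tot_num_objf;
    tot_den_objf += other.tot_den_objf;
  }
};

// Hypothetical helper: fold per-worker statistics into one total,
// starting from the all-zero state.
StatsReplica MergeAll(const std::vector<StatsReplica> &parts) {
  StatsReplica total;
  for (const StatsReplica &p : parts) total.Add(p);
  return total;
}
```

Since addition is associative and commutative (up to floating-point rounding), the order in which workers' statistics are merged does not change the logged totals in any meaningful way.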

◆ Print()

void Print ( std::string  criterion)

Definition at line 384 of file nnet-compute-discriminative.cc.

References KALDI_ASSERT, and KALDI_LOG.

Referenced by main(), and kaldi::nnet2::NnetDiscriminativeUpdateParallel().

{
  KALDI_ASSERT(criterion == "mmi" || criterion == "smbr" ||
               criterion == "mpfe");

  double avg_post_per_frame = tot_num_count / tot_t_weighted;
  KALDI_LOG << "Number of frames is " << tot_t
            << " (weighted: " << tot_t_weighted
            << "), average (num or den) posterior per frame is "
            << avg_post_per_frame;

  if (criterion == "mmi") {
    double num_objf = tot_num_objf / tot_t_weighted,
        den_objf = tot_den_objf / tot_t_weighted,
        objf = num_objf - den_objf;
    KALDI_LOG << "MMI objective function is " << num_objf << " - "
              << den_objf << " = " << objf << " per frame, over "
              << tot_t_weighted << " frames.";
  } else if (criterion == "mpfe") {
    double objf = tot_den_objf / tot_t_weighted;  // this contains the actual
                                                  // summed objf
    KALDI_LOG << "MPFE objective function is " << objf
              << " per frame, over " << tot_t_weighted << " frames.";
  } else {
    double objf = tot_den_objf / tot_t_weighted;  // this contains the actual
                                                  // summed objf
    KALDI_LOG << "SMBR objective function is " << objf
              << " per frame, over " << tot_t_weighted << " frames.";
  }
}
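The arithmetic in Print() can be isolated into a function that returns the per-frame values instead of logging them. This is a sketch under stated assumptions: StatsReplica, PerFrameSummary, and Summarize are hypothetical names, and a thrown exception stands in for KALDI_ASSERT.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical replica of the five accumulators.
struct StatsReplica {
  double tot_t = 0.0, tot_t_weighted = 0.0, tot_num_count = 0.0,
         tot_num_objf = 0.0, tot_den_objf = 0.0;
};

// The per-frame quantities that Print() logs, returned instead of logged.
struct PerFrameSummary {
  double avg_post_per_frame;
  double objf;
};

PerFrameSummary Summarize(const StatsReplica &s, const std::string &criterion) {
  // Stand-in for KALDI_ASSERT: reject any criterion Print() does not accept.
  if (criterion != "mmi" && criterion != "smbr" && criterion != "mpfe")
    throw std::invalid_argument("unknown criterion: " + criterion);
  PerFrameSummary out;
  // Every per-frame value is normalized by the weighted frame count.
  out.avg_post_per_frame = s.tot_num_count / s.tot_t_weighted;
  if (criterion == "mmi") {
    // MMI: per-frame numerator objective minus per-frame denominator objective.
    out.objf = s.tot_num_objf / s.tot_t_weighted -
               s.tot_den_objf / s.tot_t_weighted;
  } else {
    // MPFE and SMBR: tot_den_objf already holds the summed objective.
    out.objf = s.tot_den_objf / s.tot_t_weighted;
  }
  return out;
}
```

Like the original, this divides by tot_t_weighted without checking for zero, so all-zero statistics produce NaN. MPFE and SMBR share one formula; only the label in the log message differs.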

Member Data Documentation

◆ tot_den_objf

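double tot_den_objf

For "mmi", the weighted denominator objective; for "mpfe" and "smbr", Print() treats it as the summed objective function itself.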

◆ tot_num_count

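double tot_num_count

Total numerator posterior count. Print() divides it by tot_t_weighted to report the average (num or den) posterior per frame.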

◆ tot_num_objf

double tot_num_objf

Weighted numerator objective. Print() uses it only for "mmi".

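◆ tot_t

double tot_t

Total number of frames, reported unweighted by Print().

◆ tot_t_weighted

double tot_t_weighted

Weighted number of frames. Print() divides every per-frame quantity by this value.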

The documentation for this struct was generated from the following files:

nnet-compute-discriminative.h
nnet-compute-discriminative.cc