DiscriminativeObjectiveFunctionInfo Struct Reference

#include <nnet-discriminative-training.h>

Collaboration diagram for DiscriminativeObjectiveFunctionInfo:

Public Member Functions

 DiscriminativeObjectiveFunctionInfo ()
 
void UpdateStats (const std::string &output_name, const std::string &criterion, int32 minibatches_per_phase, int32 minibatch_counter, discriminative::DiscriminativeObjectiveInfo stats)
 
void PrintStatsForThisPhase (const std::string &output_name, const std::string &criterion, int32 minibatches_per_phase) const
 
bool PrintTotalStats (const std::string &output_name, const std::string &criterion) const
 

Public Attributes

int32 current_phase
 
discriminative::DiscriminativeObjectiveInfo stats
 
discriminative::DiscriminativeObjectiveInfo stats_this_phase
 

Detailed Description

Definition at line 55 of file nnet-discriminative-training.h.

Constructor & Destructor Documentation

◆ DiscriminativeObjectiveFunctionInfo()

Member Function Documentation

◆ PrintStatsForThisPhase()

void PrintStatsForThisPhase ( const std::string &  output_name,
const std::string &  criterion,
int32  minibatches_per_phase 
) const

Definition at line 225 of file nnet-discriminative-training.cc.

References KALDI_LOG.

Referenced by UpdateStats().

228  {
229  int32 start_minibatch = current_phase * minibatches_per_phase,
230  end_minibatch = start_minibatch + minibatches_per_phase - 1;
231 
232  BaseFloat objf = (stats_this_phase.TotalObjf(criterion) / stats_this_phase.tot_t_weighted);
233  KALDI_LOG << "Average objective function for '" << output_name
234  << "' for minibatches " << start_minibatch
235  << '-' << end_minibatch << " is " << objf
236  << " over " << stats_this_phase.tot_t_weighted << " frames.";
237 }
kaldi::int32 int32
double TotalObjf(const std::string &criterion) const
float BaseFloat
Definition: kaldi-types.h:29
discriminative::DiscriminativeObjectiveInfo stats_this_phase
#define KALDI_LOG
Definition: kaldi-error.h:153

◆ PrintTotalStats()

bool PrintTotalStats ( const std::string &  output_name,
const std::string &  criterion 
) const

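Definition at line 239 of file nnet-discriminative-training.cc. (In the listing below, the first parameter is named `name`; the header declaration above calls it `output_name`.)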

References KALDI_LOG.

Referenced by NnetDiscriminativeTrainer::PrintTotalStats().

240  {
241  BaseFloat objf = stats.TotalObjf(criterion) /stats.tot_t_weighted;
242 
243  double avg_gradients = (stats.tot_num_count + stats.tot_den_count) /
244  stats.tot_t_weighted;
245  KALDI_LOG << "Average num+den count of stats is " << avg_gradients
246  << " per frame, over "
247  << stats.tot_t_weighted << " frames.";
248  if (stats.tot_l2_term != 0.0) {
249  KALDI_LOG << "Average l2 norm of output per frame is "
250  << (stats.tot_l2_term / stats.tot_t_weighted) << " over "
251  << stats.tot_t_weighted << " frames.";
252  }
253 
254 
255  KALDI_LOG << "Overall average objective function for '" << name << "' is "
256  << objf << " over " << stats.tot_t_weighted << " frames.";
257  KALDI_LOG << "[this line is to be parsed by a script:] "
258  << criterion << "-per-frame="
259  << objf;
260  return (stats.tot_t_weighted != 0.0);
261 }
double TotalObjf(const std::string &criterion) const
float BaseFloat
Definition: kaldi-types.h:29
discriminative::DiscriminativeObjectiveInfo stats
#define KALDI_LOG
Definition: kaldi-error.h:153

◆ UpdateStats()

void UpdateStats ( const std::string &  output_name,
const std::string &  criterion,
int32  minibatches_per_phase,
int32  minibatch_counter,
discriminative::DiscriminativeObjectiveInfo  stats 
)

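Definition at line 208 of file nnet-discriminative-training.cc. (In the listing below, the last parameter is named `this_minibatch_stats`; the header declaration above calls it `stats`, which in the body would shadow the member `stats`.)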

References KALDI_ASSERT.

213  {
214  int32 phase = minibatch_counter / minibatches_per_phase;
215  if (phase != current_phase) {
216  KALDI_ASSERT(phase == current_phase + 1); // or doesn't really make sense.
217  PrintStatsForThisPhase(output_name, criterion, minibatches_per_phase);
218  current_phase = phase;
219  stats_this_phase.Reset();
220  }
221  stats_this_phase.Add(this_minibatch_stats);
222  stats.Add(this_minibatch_stats);
223 }
kaldi::int32 int32
void Add(const DiscriminativeObjectiveInfo &other)
discriminative::DiscriminativeObjectiveInfo stats
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
discriminative::DiscriminativeObjectiveInfo stats_this_phase
void PrintStatsForThisPhase(const std::string &output_name, const std::string &criterion, int32 minibatches_per_phase) const

Member Data Documentation

◆ current_phase

int32 current_phase

Definition at line 56 of file nnet-discriminative-training.h.

◆ stats

◆ stats_this_phase


The documentation for this struct was generated from the following files: