ObjectiveFunctionInfo Struct Reference

#include <nnet-training.h>

Collaboration diagram for ObjectiveFunctionInfo: (diagram image not included in this text rendering)

Public Member Functions

 ObjectiveFunctionInfo ()
 
void UpdateStats (const std::string &output_name, int32 minibatches_per_phase, int32 minibatch_counter, BaseFloat this_minibatch_weight, BaseFloat this_minibatch_tot_objf, BaseFloat this_minibatch_tot_aux_objf=0.0)
 
void PrintStatsForThisPhase (const std::string &output_name, int32 minibatches_per_phase, int32 phase) const
 
bool PrintTotalStats (const std::string &output_name) const
 

Public Attributes

int32 current_phase
 
int32 minibatches_this_phase
 
double tot_weight
 
double tot_objf
 
double tot_aux_objf
 
double tot_weight_this_phase
 
double tot_objf_this_phase
 
double tot_aux_objf_this_phase
 

Detailed Description

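Accumulates objective-function statistics for one named network output during training.

It keeps running totals for the whole run (tot_weight, tot_objf, tot_aux_objf) and for the current phase (the *_this_phase members). The totals are:
- the weight (reported in the logs as a number of frames);
- the objective function;
- an optional auxiliary objective.

A phase is a block of minibatches_per_phase consecutive minibatches.

UpdateStats() adds one minibatch's statistics. When the minibatch counter moves into a new phase, it first logs the finished phase's averages through PrintStatsForThisPhase() and resets the per-phase totals. PrintTotalStats() logs the overall averages and returns whether any weight was accumulated.

Definition at line 123 of file nnet-training.h.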

Constructor & Destructor Documentation

◆ ObjectiveFunctionInfo()

Member Function Documentation

◆ PrintStatsForThisPhase()

void PrintStatsForThisPhase ( const std::string &  output_name,
int32  minibatches_per_phase,
int32  phase 
) const

Definition at line 269 of file nnet-training.cc.

References KALDI_LOG.

272  {
273  int32 start_minibatch = current_phase * minibatches_per_phase,
274  end_minibatch = phase * minibatches_per_phase - 1;
275 
276  if (tot_aux_objf_this_phase == 0.0) {
277  if (minibatches_per_phase == minibatches_this_phase) {
278  KALDI_LOG << "Average objective function for '" << output_name
279  << "' for minibatches " << start_minibatch
280  << '-' << end_minibatch << " is "
281  << (tot_objf_this_phase / tot_weight_this_phase) << " over "
282  << tot_weight_this_phase << " frames.";
283  } else {
284  KALDI_LOG << "Average objective function for '" << output_name
285  << " using " << minibatches_this_phase
286  << " minibatches in minibatch range " << start_minibatch
287  << '-' << end_minibatch << " is "
288  << (tot_objf_this_phase / tot_weight_this_phase) << " over "
289  << tot_weight_this_phase << " frames.";
290  }
291  } else {
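292  BaseFloat objf = (tot_objf_this_phase / tot_weight_this_phase),
293  aux_objf = (tot_aux_objf_this_phase / tot_weight_this_phase),
294  sum_objf = objf + aux_objf;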
295  if (minibatches_per_phase == minibatches_this_phase) {
296  KALDI_LOG << "Average objective function for '" << output_name
297  << "' for minibatches " << start_minibatch
298  << '-' << end_minibatch << " is "
299  << objf << " + " << aux_objf << " = " << sum_objf
300  << " over " << tot_weight_this_phase << " frames.";
301  } else {
302  KALDI_LOG << "Average objective function for '" << output_name
303  << "' using " << minibatches_this_phase
304  << " minibatches in minibatch range " << start_minibatch
305  << '-' << end_minibatch << " is "
306  << objf << " + " << aux_objf << " = " << sum_objf
307  << " over " << tot_weight_this_phase << " frames.";
308  }
309  }
310 }
Symbols used in this listing:
  int32 - alias for kaldi::int32.
  BaseFloat - alias for float; defined at kaldi-types.h:29.
  KALDI_LOG - macro; defined at kaldi-error.h:153.

◆ PrintTotalStats()

bool PrintTotalStats ( const std::string &  output_name) const

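Definition at line 312 of file nnet-training.cc.

In that definition the parameter is named `name`, which is why the listing below refers to `name` rather than `output_name`.

The function returns true if any weight was accumulated (tot_weight != 0.0). The averages are computed and logged before that check, so if no weight was accumulated the logged values come from a division by zero.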

References KALDI_LOG.

Referenced by NnetChainTrainer::PrintTotalStats(), and NnetTrainer::PrintTotalStats().

312  {
313  BaseFloat objf = (tot_objf / tot_weight),
314  aux_objf = (tot_aux_objf / tot_weight),
315  sum_objf = objf + aux_objf;
316  if (tot_aux_objf == 0.0) {
317  KALDI_LOG << "Overall average objective function for '" << name << "' is "
318  << (tot_objf / tot_weight) << " over " << tot_weight << " frames.";
319  } else {
320  KALDI_LOG << "Overall average objective function for '" << name << "' is "
321  << objf << " + " << aux_objf << " = " << sum_objf
322  << " over " << tot_weight << " frames.";
323  }
324  KALDI_LOG << "[this line is to be parsed by a script:] "
325  << "log-prob-per-frame="
326  << objf;
327  return (tot_weight != 0.0);
328 }
Symbols used in this listing:
  BaseFloat - alias for float; defined at kaldi-types.h:29.
  KALDI_LOG - macro; defined at kaldi-error.h:153.

◆ UpdateStats()

void UpdateStats ( const std::string &  output_name,
int32  minibatches_per_phase,
int32  minibatch_counter,
BaseFloat  this_minibatch_weight,
BaseFloat  this_minibatch_tot_objf,
BaseFloat  this_minibatch_tot_aux_objf = 0.0 
)

Definition at line 242 of file nnet-training.cc.

References KALDI_ASSERT.

248  {
249  int32 phase = minibatch_counter / minibatches_per_phase;
250  if (phase != current_phase) {
251  KALDI_ASSERT(phase > current_phase);
252  PrintStatsForThisPhase(output_name, minibatches_per_phase,
253  phase);
254  current_phase = phase;
255  tot_weight_this_phase = 0.0;
256  tot_objf_this_phase = 0.0;
257  tot_aux_objf_this_phase = 0.0;
258  minibatches_this_phase = 0;
259  }
260  minibatches_this_phase++;
261  tot_weight_this_phase += this_minibatch_weight;
262  tot_objf_this_phase += this_minibatch_tot_objf;
263  tot_aux_objf_this_phase += this_minibatch_tot_aux_objf;
264  tot_weight += this_minibatch_weight;
265  tot_objf += this_minibatch_tot_objf;
266  tot_aux_objf += this_minibatch_tot_aux_objf;
267 }
Symbols used in this listing:
  int32 - alias for kaldi::int32.
  KALDI_ASSERT(cond) - macro; defined at kaldi-error.h:185.
  PrintStatsForThisPhase - void PrintStatsForThisPhase(const std::string &output_name, int32 minibatches_per_phase, int32 phase) const; documented above.

Member Data Documentation

◆ current_phase

int32 current_phase

Definition at line 124 of file nnet-training.h.

◆ minibatches_this_phase

int32 minibatches_this_phase

Definition at line 125 of file nnet-training.h.

◆ tot_aux_objf

double tot_aux_objf

Definition at line 130 of file nnet-training.h.

◆ tot_aux_objf_this_phase

double tot_aux_objf_this_phase

Definition at line 136 of file nnet-training.h.

◆ tot_objf

double tot_objf

Definition at line 129 of file nnet-training.h.

◆ tot_objf_this_phase

double tot_objf_this_phase

Definition at line 135 of file nnet-training.h.

◆ tot_weight

double tot_weight

Definition at line 128 of file nnet-training.h.

◆ tot_weight_this_phase

double tot_weight_this_phase

Definition at line 134 of file nnet-training.h.


The documentation for this struct was generated from the following files:
  - nnet-training.h
  - nnet-training.cc