MultiTaskLoss Class Reference

#include <nnet-loss.h>

Inheritance diagram for MultiTaskLoss:
Collaboration diagram for MultiTaskLoss:

Public Member Functions

 MultiTaskLoss (LossOptions &opts)
 
 ~MultiTaskLoss ()
 
void InitFromString (const std::string &s)
 Initialize from a string of the form 'multitask,<type1>,<dim1>,<weight1>,...,<typeN>,<dimN>,<weightN>'. More...
 
void Eval (const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const CuMatrixBase< BaseFloat > &target, CuMatrix< BaseFloat > *diff)
 Not supported for MultiTaskLoss (always fails with KALDI_ERR); use the Posterior-based overload. More...
 
void Eval (const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const Posterior &target, CuMatrix< BaseFloat > *diff)
 Evaluate each per-task loss on its column range of the target Posterior and combine the weighted gradients. More...
 
std::string Report ()
 Generate a string with the error report. More...
 
BaseFloat AvgLoss ()
 Get the overall loss: the weighted sum of the per-task frame-averaged losses. More...
 
- Public Member Functions inherited from LossItf
 LossItf (LossOptions &opts)
 
virtual ~LossItf ()
 

Private Attributes

std::vector< LossItf * > loss_vec_
 
std::vector< int32 > loss_dim_
 
std::vector< BaseFloat > loss_weights_
 
std::vector< int32 > loss_dim_offset_
 
CuMatrix< BaseFloat > tgt_mat_
 

Additional Inherited Members

- Protected Attributes inherited from LossItf
LossOptions opts_
 
Timer timer_
 

Detailed Description

Definition at line 197 of file nnet-loss.h.

Constructor & Destructor Documentation

◆ MultiTaskLoss()

MultiTaskLoss ( LossOptions & opts)
inline

Definition at line 199 of file nnet-loss.h.

199  :
200  LossItf(opts)
201  { }
LossItf(LossOptions &opts)
Definition: nnet-loss.h:53

◆ ~MultiTaskLoss()

~MultiTaskLoss ( )
inline

Definition at line 203 of file nnet-loss.h.

203  {
204  while (loss_vec_.size() > 0) {
205  delete loss_vec_.back();
206  loss_vec_.pop_back();
207  }
208  }
std::vector< LossItf * > loss_vec_
Definition: nnet-loss.h:238

Member Function Documentation

◆ AvgLoss()

BaseFloat AvgLoss ( )
virtual

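Get the overall loss value: the weighted sum of the per-task frame-averaged losses. A weighted term that is not finite is replaced by 0.0 and a warning is logged.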

Implements LossItf.

Definition at line 445 of file nnet-loss.cc.

References rnnlm::i, KALDI_ISFINITE, KALDI_WARN, and Xent::loss_vec_.

445  {
446  BaseFloat ans(0.0);
447  for (int32 i = 0; i < loss_vec_.size(); i++) {
448  BaseFloat val = loss_weights_[i] * loss_vec_[i]->AvgLoss();
449  if (!KALDI_ISFINITE(val)) {
450  KALDI_WARN << "Loss " << i+1 << ", has bad objective function value '"
451  << val << "', using 0.0 instead.";
452  val = 0.0;
453  }
454  ans += val;
455  }
456  return ans;
457 }
#define KALDI_ISFINITE(x)
Definition: kaldi-math.h:74
kaldi::int32 int32
std::vector< BaseFloat > loss_weights_
Definition: nnet-loss.h:240
float BaseFloat
Definition: kaldi-types.h:29
#define KALDI_WARN
Definition: kaldi-error.h:150
std::vector< LossItf * > loss_vec_
Definition: nnet-loss.h:238

◆ Eval() [1/2]

void Eval ( const VectorBase< BaseFloat > &  frame_weights,
const CuMatrixBase< BaseFloat > &  net_out,
const CuMatrixBase< BaseFloat > &  target,
CuMatrix< BaseFloat > *  diff 
)
inline virtual

Not supported for MultiTaskLoss: this overload always fails with KALDI_ERR. Use the Posterior-based overload instead.

Implements LossItf.

Definition at line 218 of file nnet-loss.h.

References KALDI_ERR.

Referenced by main().

221  {
222  KALDI_ERR << "This is not supposed to be called!";
223  }
#define KALDI_ERR
Definition: kaldi-error.h:147

◆ Eval() [2/2]

void Eval ( const VectorBase< BaseFloat > &  frame_weights,
const CuMatrixBase< BaseFloat > &  net_out,
const Posterior &  target,
CuMatrix< BaseFloat > *  diff 
)
virtual

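Evaluate the multi-task loss using a target Posterior. The Posterior is converted to a target matrix. Each per-task loss is evaluated on its own column range, and its gradient is scaled by the task weight and copied into the matching columns of 'diff'.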

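One vector of frame weights is prepared per loss function. The original frame weights are multiplied by a mask of 'defined targets' derived from the Posterior. A frame therefore gets weight 0 for a given loss when none of its posterior entries fall into that loss's column range.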

Implements LossItf.

Definition at line 365 of file nnet-loss.cc.

References CuMatrixBase< Real >::ColRange(), KALDI_ASSERT, Xent::loss_vec_, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), kaldi::nnet1::PosteriorToMatrix(), CuMatrix< Real >::Resize(), and Xent::tgt_mat_.

368  {
369  int32 num_frames = net_out.NumRows(),
370  num_output = net_out.NumCols();
371  KALDI_ASSERT(num_frames == post.size());
372  KALDI_ASSERT(num_output == loss_dim_offset_.back()); // sum of loss-dims,
373 
374  // convert posterior to matrix,
375  PosteriorToMatrix(post, num_output, &tgt_mat_);
376 
377  // allocate diff matrix,
378  diff->Resize(num_frames, num_output);
379 
383  std::vector<Vector<BaseFloat> > frmwei_have_tgt;
384  for (int32 l = 0; l < loss_vec_.size(); l++) {
385  // copy original weights,
386  frmwei_have_tgt.push_back(Vector<BaseFloat>(frame_weights));
387  // We need to mask-out the frames for which the 'posterior' is not defined (= is empty):
388  int32 loss_beg = loss_dim_offset_[l]; // first column of loss target,
389  int32 loss_end = loss_dim_offset_[l+1]; // (last+1) column of loss target,
390  for (int32 f = 0; f < num_frames; f++) {
391  bool tgt_defined = false;
392  for (int32 p = 0; p < post[f].size(); p++) {
393  if (post[f][p].first >= loss_beg && post[f][p].first < loss_end) {
394  tgt_defined = true;
395  break;
396  }
397  }
398  if (!tgt_defined) {
399  frmwei_have_tgt[l](f) = 0.0; // set zero_weight for the frame with no targets!
400  }
401  }
402  }
403 
404  // call the vector of loss functions,
405  CuMatrix<BaseFloat> diff_aux;
406  for (int32 l = 0; l < loss_vec_.size(); l++) {
407  loss_vec_[l]->Eval(frmwei_have_tgt[l],
408  net_out.ColRange(loss_dim_offset_[l], loss_dim_[l]),
409  tgt_mat_.ColRange(loss_dim_offset_[l], loss_dim_[l]),
410  &diff_aux);
411  // Scale the gradients,
412  diff_aux.Scale(loss_weights_[l]);
413  // Copy to diff,
414  diff->ColRange(loss_dim_offset_[l], loss_dim_[l]).CopyFromMat(diff_aux);
415  }
416 }
kaldi::int32 int32
void PosteriorToMatrix(const Posterior &post, const int32 post_dim, CuMatrix< Real > *mat)
Wrapper of PosteriorToMatrix with CuMatrix argument.
Definition: nnet-utils.h:292
std::vector< BaseFloat > loss_weights_
Definition: nnet-loss.h:240
CuMatrix< BaseFloat > tgt_mat_
Definition: nnet-loss.h:244
std::vector< int32 > loss_dim_offset_
Definition: nnet-loss.h:242
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
std::vector< int32 > loss_dim_
Definition: nnet-loss.h:239
std::vector< LossItf * > loss_vec_
Definition: nnet-loss.h:238

◆ InitFromString()

void InitFromString ( const std::string &  s)

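Initialize from a string of the form 'multitask,<type1>,<dim1>,<weight1>,...,<typeN>,<dimN>,<weightN>'. Each field is defined as follows:
- <type> is 'xent' or 'mse'.
- <dim> is an integer.
- <weight> is a non-negative real number.
Fields may be separated by ',' or ':'.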

For example: 'multitask,xent,2456,1.0,mse,440,0.001'.

Definition at line 318 of file nnet-loss.cc.

References kaldi::ConvertStringToInteger(), kaldi::ConvertStringToReal(), rnnlm::i, KALDI_ASSERT, KALDI_ERR, Xent::loss_vec_, LossItf::opts_, kaldi::SplitStringToVector(), and Xent::Xent().

Referenced by main().

318  {
319  std::vector<std::string> v;
320  SplitStringToVector(s, ",:" /* delimiter */, false, &v);
321 
322  KALDI_ASSERT((v.size()-1) % 3 == 0); // triplets,
323  KALDI_ASSERT(v[0] == "multitask"); // header,
324 
325  // parse the definition of multitask loss,
326  std::vector<std::string>::iterator it(v.begin()+1); // skip header,
327  for ( ; it != v.end(); ++it) {
328  // type,
329  if (*it == "xent") {
330  loss_vec_.push_back(new Xent(opts_));
331  } else if (*it == "mse") {
332  loss_vec_.push_back(new Mse(opts_));
333  } else {
334  KALDI_ERR << "Unknown objective function code : " << *it;
335  }
336  ++it;
337  // dim,
338  int32 dim;
339  if (!ConvertStringToInteger(*it, &dim)) {
340  KALDI_ERR << "Cannot convert 'dim' " << *it << " to integer!";
341  }
342  loss_dim_.push_back(dim);
343  ++it;
344  // weight,
345  BaseFloat weight;
346  if (!ConvertStringToReal(*it, &weight)) {
347  KALDI_ERR << "Cannot convert 'weight' " << *it << " to integer!";
348  }
349  KALDI_ASSERT(weight >= 0.0);
350  loss_weights_.push_back(weight);
351  }
352 
353  // build vector with starting-point offsets,
354  loss_dim_offset_.resize(loss_dim_.size()+1, 0); // 1st zero stays,
355  for (int32 i = 1; i <= loss_dim_.size(); i++) {
357  }
358 
359  // sanity check,
360  KALDI_ASSERT(loss_vec_.size() > 0);
361  KALDI_ASSERT(loss_vec_.size() == loss_dim_.size());
362  KALDI_ASSERT(loss_vec_.size() == loss_weights_.size());
363 }
bool ConvertStringToInteger(const std::string &str, Int *out)
Converts a string into an integer via strtoll and returns false if there was any kind of problem (i...
Definition: text-utils.h:118
kaldi::int32 int32
std::vector< BaseFloat > loss_weights_
Definition: nnet-loss.h:240
float BaseFloat
Definition: kaldi-types.h:29
std::vector< int32 > loss_dim_offset_
Definition: nnet-loss.h:242
void SplitStringToVector(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< std::string > *out)
Split a string using any of the single character delimiters.
Definition: text-utils.cc:63
#define KALDI_ERR
Definition: kaldi-error.h:147
bool ConvertStringToReal(const std::string &str, T *out)
ConvertStringToReal converts a string into either float or double and returns false if there was any ...
Definition: text-utils.cc:238
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
std::vector< int32 > loss_dim_
Definition: nnet-loss.h:239
std::vector< LossItf * > loss_vec_
Definition: nnet-loss.h:238
LossOptions opts_
Definition: nnet-loss.h:77

◆ Report()

std::string Report ( )
virtual

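Generate a string with the error report. It contains the report of each individual loss, followed by the overall weighted loss together with the task weights and the per-task loss values.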

Implements LossItf.

Definition at line 418 of file nnet-loss.cc.

References Xent::AvgLoss(), rnnlm::i, and Xent::loss_vec_.

Referenced by main().

418  {
419  // calculate overall loss (weighted),
420  BaseFloat overall_loss = AvgLoss();
421  // copy the loss-values into a vector,
422  std::vector<BaseFloat> loss_values;
423  for (int32 i = 0; i < loss_vec_.size(); i++) {
424  loss_values.push_back(loss_vec_[i]->AvgLoss());
425  }
426 
427  // build the message,
428  std::ostringstream oss;
429  oss << "MultiTaskLoss, with " << loss_vec_.size()
430  << " parallel loss functions." << std::endl;
431  // individual loss reports first,
432  for (int32 i = 0; i < loss_vec_.size(); i++) {
433  oss << "Loss " << i+1 << ", " << loss_vec_[i]->Report() << std::endl;
434  }
435 
436  // overall loss is last,
437  oss << "Loss (OVERALL), "
438  << "AvgLoss: " << overall_loss << " (MultiTaskLoss), "
439  << "weights " << loss_weights_ << ", "
440  << "values " << loss_values << std::endl;
441 
442  return oss.str();
443 }
kaldi::int32 int32
BaseFloat AvgLoss()
Get loss value (frame average),.
Definition: nnet-loss.cc:445
std::vector< BaseFloat > loss_weights_
Definition: nnet-loss.h:240
float BaseFloat
Definition: kaldi-types.h:29
std::vector< LossItf * > loss_vec_
Definition: nnet-loss.h:238

Member Data Documentation

◆ loss_dim_

std::vector<int32> loss_dim_
private

Definition at line 239 of file nnet-loss.h.

◆ loss_dim_offset_

std::vector<int32> loss_dim_offset_
private

Definition at line 242 of file nnet-loss.h.

◆ loss_vec_

std::vector<LossItf*> loss_vec_
private

Definition at line 238 of file nnet-loss.h.

◆ loss_weights_

std::vector<BaseFloat> loss_weights_
private

Definition at line 240 of file nnet-loss.h.

◆ tgt_mat_

CuMatrix<BaseFloat> tgt_mat_
private

Definition at line 244 of file nnet-loss.h.


The documentation for this class was generated from the following files: