#include <nnet-loss.h>
Public Member Functions
  MultiTaskLoss (LossOptions &opts)
  ~MultiTaskLoss ()
  void InitFromString (const std::string &s)
      Initialize from string; the format for string 's' is: 'multitask,<type1>,<dim1>,<weight1>,...,<typeN>,<dimN>,<weightN>'.
  void Eval (const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const CuMatrixBase< BaseFloat > &target, CuMatrix< BaseFloat > *diff)
      Evaluate mean square error using a target matrix.
  void Eval (const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const Posterior &target, CuMatrix< BaseFloat > *diff)
      Evaluate mean square error using target posteriors.
  std::string Report ()
      Generate a string with the error report.
  BaseFloat AvgLoss ()
      Get the loss value (frame average).
Public Member Functions inherited from LossItf
  LossItf (LossOptions &opts)
  virtual ~LossItf ()
Private Attributes
  std::vector< LossItf * > loss_vec_
  std::vector< int32 > loss_dim_
  std::vector< BaseFloat > loss_weights_
  std::vector< int32 > loss_dim_offset_
  CuMatrix< BaseFloat > tgt_mat_
Additional Inherited Members
Protected Attributes inherited from LossItf
  LossOptions opts_
  Timer timer_
Definition at line 197 of file nnet-loss.h.
Constructor & Destructor Documentation

MultiTaskLoss (LossOptions &opts)  [inline]
Definition at line 199 of file nnet-loss.h.
~MultiTaskLoss ()  [inline]
Definition at line 203 of file nnet-loss.h.
Member Function Documentation

BaseFloat AvgLoss ()  [virtual]
Get the loss value (frame average).
Implements LossItf.
Definition at line 445 of file nnet-loss.cc.
References rnnlm::i, KALDI_ISFINITE, KALDI_WARN, and Xent::loss_vec_.
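The references to KALDI_ISFINITE and KALDI_WARN suggest that AvgLoss() guards against non-finite values. The sketch below assumes the combined value is the sum of each sub-loss's frame-averaged value scaled by its weight from loss_weights_, with a non-finite term replaced by 0.0 and a warning printed. It is plain C++ with no Kaldi dependency; CombineAvgLoss is a hypothetical name, not part of the Kaldi API.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical stand-in for MultiTaskLoss::AvgLoss(): the weighted sum of
// per-task frame-averaged losses. A non-finite term is replaced by 0.0 and a
// warning is written to stderr (assumed analogue of KALDI_ISFINITE/KALDI_WARN).
float CombineAvgLoss(const std::vector<float> &task_avg_loss,
                     const std::vector<float> &task_weight) {
  assert(task_avg_loss.size() == task_weight.size());
  float ans = 0.0f;
  for (std::size_t i = 0; i < task_avg_loss.size(); ++i) {
    float val = task_weight[i] * task_avg_loss[i];
    if (!std::isfinite(val)) {
      std::cerr << "WARNING: loss " << i + 1 << " has bad objective value '"
                << val << "', using 0.0 instead.\n";
      val = 0.0f;
    }
    ans += val;
  }
  return ans;
}
```

For example, CombineAvgLoss({2.5f, 1000.0f}, {1.0f, 0.001f}) returns approximately 3.5 (2.5 × 1.0 + 1000 × 0.001), while a NaN term simply drops out of the sum instead of poisoning it.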
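void Eval (const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const CuMatrixBase< BaseFloat > &target, CuMatrix< BaseFloat > *diff)  [inline, virtual]
Evaluate mean square error using a target matrix.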
Implements LossItf.
Definition at line 218 of file nnet-loss.h.
References KALDI_ERR.
Referenced by main().
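void Eval (const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const Posterior &target, CuMatrix< BaseFloat > *diff)  [virtual]
Evaluate mean square error using target posteriors.
One vector of frame_weights is used per loss function: the original frame weights are multiplied by a mask of 'defined targets' derived from the Posterior.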
Implements LossItf.
Definition at line 365 of file nnet-loss.cc.
References CuMatrixBase< Real >::ColRange(), KALDI_ASSERT, Xent::loss_vec_, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), kaldi::nnet1::PosteriorToMatrix(), CuMatrix< Real >::Resize(), and Xent::tgt_mat_.
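The description above leaves the masking rule implicit. A minimal sketch, assuming that a frame's target counts as 'defined' for task i when its Posterior entry contains at least one index inside that task's column range [loss_dim_offset_[i], loss_dim_offset_[i] + loss_dim_[i]), and that the mask is 1 for defined and 0 otherwise. Posterior is redeclared here with the same shape as Kaldi's typedef (per frame, a list of (index, weight) pairs); PerTaskFrameWeights is a hypothetical helper, not a Kaldi function.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Same shape as Kaldi's Posterior: per frame, a list of (target index, weight).
using Posterior = std::vector<std::vector<std::pair<int32_t, float>>>;

// Hypothetical helper: for task t, whose outputs occupy columns
// [offset[t], offset[t] + dim[t]), keep a frame's weight only if the frame's
// posterior has at least one target index inside that column range.
std::vector<std::vector<float>> PerTaskFrameWeights(
    const std::vector<float> &frame_weights, const Posterior &post,
    const std::vector<int32_t> &dim, const std::vector<int32_t> &offset) {
  assert(frame_weights.size() == post.size());
  assert(dim.size() == offset.size());
  std::vector<std::vector<float>> task_weights(
      dim.size(), std::vector<float>(post.size(), 0.0f));
  for (std::size_t t = 0; t < dim.size(); ++t) {
    for (std::size_t f = 0; f < post.size(); ++f) {
      bool defined = false;
      for (const auto &entry : post[f]) {
        if (entry.first >= offset[t] && entry.first < offset[t] + dim[t]) {
          defined = true;
          break;
        }
      }
      // Mask is 1 for a defined target, 0 otherwise.
      task_weights[t][f] = defined ? frame_weights[f] : 0.0f;
    }
  }
  return task_weights;
}
```

Each task's loss would then be evaluated only on its own column range of net_out (the page lists ColRange() among the references), so frames without a target for that task contribute nothing to it.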
void InitFromString (const std::string &s)
Initialize from string; the format for string 's' is: 'multitask,<type1>,<dim1>,<weight1>,...,<typeN>,<dimN>,<weightN>'.
In practice it can look like this: 'multitask,xent,2456,1.0,mse,440,0.001'
Definition at line 318 of file nnet-loss.cc.
References kaldi::ConvertStringToInteger(), kaldi::ConvertStringToReal(), rnnlm::i, KALDI_ASSERT, KALDI_ERR, Xent::loss_vec_, LossItf::opts_, kaldi::SplitStringToVector(), and Xent::Xent().
Referenced by main().
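A standalone sketch of parsing the 'multitask' string. The real method is documented as using SplitStringToVector, ConvertStringToInteger and ConvertStringToReal and as constructing sub-losses such as Xent; this version only records the (type, dim, weight) triples with the standard library and computes column offsets. The struct MultiTaskSpec and the function ParseMultiTask are hypothetical, and storing the offsets as running sums with a final total (as loss_dim_offset_ might) is an assumption. Type names are not validated.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical parsed form of a 'multitask,...' specification.
struct MultiTaskSpec {
  std::vector<std::string> type;  // loss type per task, e.g. "xent", "mse"
  std::vector<int32_t> dim;       // number of output columns per task
  std::vector<float> weight;      // loss weight per task
  std::vector<int32_t> offset;    // running sums of dim; offset.back() is the total
};

// Parses 'multitask,<type1>,<dim1>,<weight1>,...,<typeN>,<dimN>,<weightN>'.
// Throws std::invalid_argument on a malformed layout (std::stoi/std::stof
// also throw on non-numeric dims or weights).
MultiTaskSpec ParseMultiTask(const std::string &s) {
  std::vector<std::string> tok;
  std::stringstream ss(s);
  std::string item;
  while (std::getline(ss, item, ',')) tok.push_back(item);

  // Expect the keyword followed by one or more (type, dim, weight) triples.
  if (tok.size() < 4 || tok[0] != "multitask" || (tok.size() - 1) % 3 != 0) {
    throw std::invalid_argument("bad multitask string: " + s);
  }

  MultiTaskSpec spec;
  spec.offset.push_back(0);
  for (std::size_t i = 1; i < tok.size(); i += 3) {
    spec.type.push_back(tok[i]);
    spec.dim.push_back(static_cast<int32_t>(std::stoi(tok[i + 1])));
    spec.weight.push_back(std::stof(tok[i + 2]));
    spec.offset.push_back(spec.offset.back() + spec.dim.back());
  }
  assert(spec.offset.size() == spec.dim.size() + 1);
  return spec;
}
```

For 'multitask,xent,2456,1.0,mse,440,0.001' the offsets are 0, 2456 and 2896: the first 2456 output columns belong to the xent task and the next 440 to the mse task, so the network output would need 2896 columns.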
std::string Report ()  [virtual]
Generate a string with the error report.
Implements LossItf.
Definition at line 418 of file nnet-loss.cc.
References Xent::AvgLoss(), rnnlm::i, and Xent::loss_vec_.
Referenced by main().
Member Data Documentation

std::vector< int32 > loss_dim_  [private]
Definition at line 239 of file nnet-loss.h.
std::vector< int32 > loss_dim_offset_  [private]
Definition at line 242 of file nnet-loss.h.
std::vector< LossItf * > loss_vec_  [private]
Definition at line 238 of file nnet-loss.h.
std::vector< BaseFloat > loss_weights_  [private]
Definition at line 240 of file nnet-loss.h.
CuMatrix< BaseFloat > tgt_mat_  [private]
Definition at line 244 of file nnet-loss.h.