#include <nnet-loss.h>
Public Member Functions
  Mse (LossOptions &opts)
  ~Mse ()
  void Eval (const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const CuMatrixBase< BaseFloat > &target, CuMatrix< BaseFloat > *diff)
    Evaluate mean square error using a target matrix.
  void Eval (const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const Posterior &target, CuMatrix< BaseFloat > *diff)
    Evaluate mean square error using a target posterior.
  std::string Report ()
    Generate a string with the error report.
  BaseFloat AvgLoss ()
    Get the loss value (frame average).
Public Member Functions inherited from LossItf
  LossItf (LossOptions &opts)
  virtual ~LossItf ()
Private Attributes
  double frames_
  double loss_
  double frames_progress_
  double loss_progress_
  std::vector< float > loss_vec_
  CuVector< BaseFloat > frame_weights_
  CuMatrix< BaseFloat > tgt_mat_
  CuMatrix< BaseFloat > diff_pow_2_
Additional Inherited Members
Protected Attributes inherited from LossItf
  LossOptions opts_
  Timer timer_
Definition at line 149 of file nnet-loss.h.
Constructor & Destructor Documentation
Mse (LossOptions &opts) [inline]
Definition at line 151 of file nnet-loss.h.
~Mse () [inline]
Definition at line 159 of file nnet-loss.h.
Member Function Documentation
BaseFloat AvgLoss () [inline, virtual]
void Eval (const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const CuMatrixBase< BaseFloat > &target, CuMatrix< BaseFloat > *diff) [virtual]
Evaluate mean square error using a target matrix.
Implements LossItf.
Definition at line 228 of file nnet-loss.cc.
References CuMatrixBase< Real >::AddMat(), VectorBase< Real >::Dim(), Mse::frame_weights_, Mse::frames_, Mse::frames_progress_, KALDI_ASSERT, KALDI_ISFINITE, KALDI_LOG, LossOptions::loss_report_frames, Mse::loss_vec_, CuMatrixBase< Real >::MulElements(), CuMatrixBase< Real >::MulRowsVec(), CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), LossItf::opts_, VectorBase< Real >::Sum(), and CuMatrixBase< Real >::Sum().
Referenced by main().
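The references above (AddMat, MulRowsVec, MulElements, Sum, KALDI_ISFINITE) point to the usual frame-weighted squared-error computation: subtract the target, weight each row by its frame weight, square, and sum. The following is a minimal stand-alone sketch of that idea, not Kaldi's implementation. It uses plain std::vector matrices instead of CuMatrix, the function name WeightedMseLoss is hypothetical, and the 0.5 factor is an assumption chosen so that diff is exactly the gradient of the returned loss with respect to net_out.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Dense row-major stand-in for a matrix (rows = frames); not Kaldi's CuMatrix.
using Matrix = std::vector<std::vector<float>>;

// Hypothetical helper: returns the frame-weighted squared error of one
// minibatch and writes the derivative w.r.t. net_out into *diff.
//   loss    = 0.5 * sum_t w_t * sum_d (y_td - t_td)^2
//   diff_td = w_t * (y_td - t_td)
double WeightedMseLoss(const std::vector<float> &frame_weights,
                       const Matrix &net_out, const Matrix &target,
                       Matrix *diff) {
  // Shape checks analogous to the NumRows/NumCols/Dim assertions.
  assert(net_out.size() == target.size());
  assert(net_out.size() == frame_weights.size());
  diff->assign(net_out.size(), std::vector<float>());
  double loss = 0.0;
  for (std::size_t t = 0; t < net_out.size(); ++t) {
    assert(net_out[t].size() == target[t].size());
    (*diff)[t].resize(net_out[t].size());
    for (std::size_t d = 0; d < net_out[t].size(); ++d) {
      const double err = static_cast<double>(net_out[t][d]) - target[t][d];
      (*diff)[t][d] = static_cast<float>(frame_weights[t] * err);  // weighted (y - t)
      loss += 0.5 * frame_weights[t] * err * err;                  // weighted (y - t)^2
    }
  }
  assert(std::isfinite(loss));  // analogous to the KALDI_ISFINITE check
  return loss;
}
```

In this sketch a frame weight of zero removes that frame from both the loss and the gradient, which is one way padded or unwanted frames can be masked out.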
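void Eval (const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const Posterior &target, CuMatrix< BaseFloat > *diff) [virtual]
Evaluate mean square error using a target posterior.
Implements LossItf.
Definition at line 283 of file nnet-loss.cc.
References Mse::Eval(), KALDI_ASSERT, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), kaldi::nnet1::PosteriorToMatrix(), and Mse::tgt_mat_.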
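The references (PosteriorToMatrix, tgt_mat_, Eval) suggest that this overload expands the sparse posterior into a dense target matrix and then reuses the matrix overload. A Kaldi Posterior is a per-frame list of (index, weight) pairs. The sketch below shows only the densifying step under that layout; DensifyPosterior is a hypothetical name, it is not kaldi::nnet1::PosteriorToMatrix, and summing repeated indices on one frame is a choice made for this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Same layout as Kaldi's Posterior: one list of (column, weight) pairs per frame.
using Posterior = std::vector<std::vector<std::pair<std::int32_t, float>>>;
using Matrix = std::vector<std::vector<float>>;

// Hypothetical helper: expand a sparse posterior into a dense
// num_frames x num_cols target matrix (zeros where no entry exists).
Matrix DensifyPosterior(const Posterior &post, std::int32_t num_cols) {
  assert(num_cols > 0);
  Matrix mat(post.size(),
             std::vector<float>(static_cast<std::size_t>(num_cols), 0.0f));
  for (std::size_t t = 0; t < post.size(); ++t) {
    for (const auto &entry : post[t]) {
      // Column index must fit the network output dimension.
      assert(entry.first >= 0 && entry.first < num_cols);
      mat[t][static_cast<std::size_t>(entry.first)] += entry.second;
    }
  }
  return mat;
}
```

The dense result can then be passed to a matrix-based loss such as the weighted squared error, with num_cols set to the number of network outputs.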
std::string Report () [virtual]
Generate a string with the error report.
Implements LossItf.
Definition at line 299 of file nnet-loss.cc.
References Mse::frames_, and Mse::loss_vec_.
Referenced by main().
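Since Report() is listed as reading frames_ and loss_vec_, a report of this kind plausibly combines the frame-averaged loss with the stored progress values. The sketch below is an assumption about the shape of such a report, not the actual output format of Mse::Report(); the function name FormatMseReport and the text layout are hypothetical.

```cpp
#include <cassert>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical formatter: summarizes the accumulated loss per frame and
// lists the per-window progress values that were stored along the way.
std::string FormatMseReport(double loss, double frames,
                            const std::vector<float> &loss_vec) {
  assert(frames >= 0.0);
  // Frame-averaged loss; report 0 for an empty accumulator instead of NaN.
  const double avg = frames > 0.0 ? loss / frames : 0.0;
  std::ostringstream oss;
  oss << "AvgLoss: " << avg << " (Mse), [frames " << frames << "]\n";
  oss << "progress: [";
  for (float v : loss_vec) oss << v << " ";
  oss << "]\n";
  return oss.str();
}
```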
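Member Data Documentation
CuMatrix< BaseFloat > diff_pow_2_ [private]
Definition at line 193 of file nnet-loss.h.
CuVector< BaseFloat > frame_weights_ [private]
Definition at line 191 of file nnet-loss.h.
double frames_ [private]
Definition at line 184 of file nnet-loss.h.
double frames_progress_ [private]
Definition at line 187 of file nnet-loss.h.
double loss_ [private]
Definition at line 185 of file nnet-loss.h.
double loss_progress_ [private]
Definition at line 188 of file nnet-loss.h.
std::vector< float > loss_vec_ [private]
Definition at line 189 of file nnet-loss.h.
CuMatrix< BaseFloat > tgt_mat_ [private]
Definition at line 192 of file nnet-loss.h.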
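The private counters (frames_, loss_, frames_progress_, loss_progress_, loss_vec_) together with LossOptions::loss_report_frames suggest two accumulators: a whole-run total used by AvgLoss(), and a sliding window that periodically stores a progress value. The sketch below models that interplay under stated assumptions. MseAccumulator is a hypothetical name, the rule "store once the window exceeds report_frames, then reset" is an assumption, and batch_frames would correspond to the sum of the frame weights (compare the VectorBase::Sum() reference).

```cpp
#include <cassert>
#include <vector>

// Hypothetical accumulator mirroring the private attributes above.
struct MseAccumulator {
  double frames = 0.0, loss = 0.0;                    // whole-run totals
  double frames_progress = 0.0, loss_progress = 0.0;  // current window
  std::vector<float> loss_vec;                        // one value per finished window
  double report_frames;                               // window length; <= 0 disables

  explicit MseAccumulator(double window_frames) : report_frames(window_frames) {}

  // Add one minibatch: its summed loss and its (weighted) frame count.
  void Add(double batch_loss, double batch_frames) {
    assert(batch_frames >= 0.0);
    loss += batch_loss;
    frames += batch_frames;
    if (report_frames > 0.0) {
      frames_progress += batch_frames;
      loss_progress += batch_loss;
      if (frames_progress > report_frames) {
        // Store the window's per-frame loss, then start a new window.
        loss_vec.push_back(static_cast<float>(loss_progress / frames_progress));
        frames_progress = 0.0;
        loss_progress = 0.0;
      }
    }
  }

  // Frame-averaged loss over everything seen so far (NaN if no frames yet).
  double AvgLoss() const { return loss / frames; }
};
```

Keeping the window separate from the totals lets the progress values show how the loss evolves during training, while AvgLoss() still reflects the entire run.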