#include <nnet-loss.h>


Public Member Functions

| Return type | Member | Description |
|---|---|---|
| | Mse (LossOptions &opts) | |
| | ~Mse () | |
| void | Eval (const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const CuMatrixBase< BaseFloat > &target, CuMatrix< BaseFloat > *diff) | Evaluate mean square error using a target matrix. |
| void | Eval (const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const Posterior &target, CuMatrix< BaseFloat > *diff) | Evaluate mean square error using a target posterior. |
| std::string | Report () | Generate a string with the error report. |
| BaseFloat | AvgLoss () | Get the loss value (frame average). |
Public Member Functions inherited from LossItf

| Return type | Member |
|---|---|
| | LossItf (LossOptions &opts) |
| virtual | ~LossItf () |
Private Attributes

| Type | Member |
|---|---|
| double | frames_ |
| double | loss_ |
| double | frames_progress_ |
| double | loss_progress_ |
| std::vector< float > | loss_vec_ |
| CuVector< BaseFloat > | frame_weights_ |
| CuMatrix< BaseFloat > | tgt_mat_ |
| CuMatrix< BaseFloat > | diff_pow_2_ |
Additional Inherited Members

Protected Attributes inherited from LossItf

| Type | Member |
|---|---|
| LossOptions | opts_ |
| Timer | timer_ |
Definition at line 149 of file nnet-loss.h.
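The private attributes of Mse pair whole-run accumulators (frames_, loss_) with a progress window (frames_progress_, loss_progress_) and a history (loss_vec_), and the matrix Eval() documentation references LossOptions::loss_report_frames. The sketch below shows one way such a progress window could work, assuming the window is closed once it exceeds loss_report_frames frames and its average is appended to the history. LossAccumulator is a hypothetical class, not part of Kaldi, and AvgLoss() returning loss_ / frames_ is an assumption drawn from the "frame average" description.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for the Mse accumulators. Member names mirror the
// private attributes of Mse, but the update rule is an assumption.
class LossAccumulator {
 public:
  explicit LossAccumulator(double loss_report_frames)
      : loss_report_frames_(loss_report_frames) {}

  // Add the summed loss and (weighted) frame count of one mini-batch.
  void Add(double batch_loss, double batch_frames) {
    assert(batch_frames >= 0.0);
    loss_ += batch_loss;
    frames_ += batch_frames;
    if (loss_report_frames_ > 0) {
      loss_progress_ += batch_loss;
      frames_progress_ += batch_frames;
      if (frames_progress_ > loss_report_frames_) {
        // Store the average over the last window, then start a new window.
        loss_vec_.push_back(
            static_cast<float>(loss_progress_ / frames_progress_));
        loss_progress_ = 0.0;
        frames_progress_ = 0.0;
      }
    }
  }

  // Frame-averaged loss over everything accumulated so far (assumed).
  double AvgLoss() const { return loss_ / frames_; }
  const std::vector<float>& LossVec() const { return loss_vec_; }

 private:
  double loss_report_frames_;
  double frames_ = 0.0, loss_ = 0.0;
  double frames_progress_ = 0.0, loss_progress_ = 0.0;
  std::vector<float> loss_vec_;
};
```

Keeping the window counters separate from the whole-run totals lets periodic progress values be logged without disturbing the final frame average.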
Mse (LossOptions &opts) [inline]
Definition at line 151 of file nnet-loss.h.
~Mse () [inline]
Definition at line 159 of file nnet-loss.h.
BaseFloat AvgLoss () [inline, virtual]
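void Eval (const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const CuMatrixBase< BaseFloat > &target, CuMatrix< BaseFloat > *diff) [virtual]
Evaluate mean square error using a target matrix.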
Implements LossItf.
Definition at line 228 of file nnet-loss.cc.
References CuMatrixBase< Real >::AddMat(), VectorBase< Real >::Dim(), Mse::frame_weights_, Mse::frames_, Mse::frames_progress_, KALDI_ASSERT, KALDI_ISFINITE, KALDI_LOG, LossOptions::loss_report_frames, Mse::loss_vec_, CuMatrixBase< Real >::MulElements(), CuMatrixBase< Real >::MulRowsVec(), CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), LossItf::opts_, VectorBase< Real >::Sum(), and CuMatrixBase< Real >::Sum().
Referenced by main().
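The References list for this overload (AddMat, MulRowsVec, MulElements, Sum, KALDI_ISFINITE) hints at the computation: subtract the target, weight each row by its frame weight, square element-wise and sum. Below is a minimal sketch of a frame-weighted MSE and its gradient in plain C++, assuming loss = 0.5 * sum_t w_t * sum_d (y_td - t_td)^2 and diff_td = w_t * (y_td - t_td). The Matrix struct and WeightedMse function are hypothetical stand-ins, not Kaldi API, and the exact scaling used in nnet-loss.cc may differ.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Row-major dense matrix: a hypothetical stand-in for CuMatrix<BaseFloat>.
struct Matrix {
  std::size_t rows = 0, cols = 0;
  std::vector<double> data;
  Matrix(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0.0) {}
  double& at(std::size_t r, std::size_t c) { return data[r * cols + c]; }
  double at(std::size_t r, std::size_t c) const { return data[r * cols + c]; }
};

// Frame-weighted mean square error (assumed formulation):
//   loss    = 0.5 * sum_t w_t * sum_d (y_td - t_td)^2
//   diff_td = w_t * (y_td - t_td)   (derivative of loss w.r.t. y_td)
// Returns the summed (not frame-averaged) loss of this mini-batch.
double WeightedMse(const std::vector<double>& frame_weights,
                   const Matrix& net_out, const Matrix& target,
                   Matrix* diff) {
  assert(net_out.rows == target.rows && net_out.cols == target.cols);
  assert(net_out.rows == frame_weights.size());
  *diff = Matrix(net_out.rows, net_out.cols);
  double loss = 0.0;
  for (std::size_t t = 0; t < net_out.rows; ++t) {
    for (std::size_t d = 0; d < net_out.cols; ++d) {
      double err = net_out.at(t, d) - target.at(t, d);  // y - t
      diff->at(t, d) = frame_weights[t] * err;          // weighted gradient
      loss += 0.5 * frame_weights[t] * err * err;       // weighted squared error
    }
  }
  assert(std::isfinite(loss));
  return loss;
}
```

With this formulation a frame weight of 0 removes that frame from both the loss and the gradient, so frame weights can be used to mask padded or unreliable frames.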
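void Eval (const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const Posterior &target, CuMatrix< BaseFloat > *diff) [virtual]
Evaluate mean square error using a target posterior.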
Implements LossItf.
Definition at line 283 of file nnet-loss.cc.
References Mse::Eval(), KALDI_ASSERT, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), kaldi::nnet1::PosteriorToMatrix(), and Mse::tgt_mat_.
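This overload references kaldi::nnet1::PosteriorToMatrix() and the tgt_mat_ member, which suggests it densifies the posterior into a target matrix and then reuses the matrix overload. A hedged sketch of that idea follows: the Posterior alias mirrors Kaldi's per-frame list of (index, weight) pairs, while Dense, PosteriorToDense, WeightedMse and WeightedMsePosterior are hypothetical helpers and not the library's signatures.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Mirrors the layout of kaldi::Posterior: for each frame, a list of
// (target index, weight) pairs.
using Posterior = std::vector<std::vector<std::pair<int, float>>>;
// Hypothetical dense stand-in for CuMatrix<BaseFloat>: one row per frame.
using Dense = std::vector<std::vector<double>>;

// Hypothetical helper: scatter a sparse posterior into a dense
// (num_frames x num_cols) target matrix, leaving zeros elsewhere.
Dense PosteriorToDense(const Posterior& post, std::size_t num_cols) {
  Dense mat(post.size(), std::vector<double>(num_cols, 0.0));
  for (std::size_t t = 0; t < post.size(); ++t) {
    for (const auto& entry : post[t]) {
      assert(entry.first >= 0 &&
             static_cast<std::size_t>(entry.first) < num_cols);
      mat[t][static_cast<std::size_t>(entry.first)] += entry.second;
    }
  }
  return mat;
}

// Frame-weighted squared error, 0.5 * sum_t w_t * sum_d (y_td - t_td)^2
// (assumed scaling; the gradient is omitted for brevity).
double WeightedMse(const std::vector<double>& frame_weights,
                   const Dense& net_out, const Dense& target) {
  assert(net_out.size() == target.size() &&
         net_out.size() == frame_weights.size());
  double loss = 0.0;
  for (std::size_t t = 0; t < net_out.size(); ++t) {
    assert(net_out[t].size() == target[t].size());
    for (std::size_t d = 0; d < net_out[t].size(); ++d) {
      double err = net_out[t][d] - target[t][d];
      loss += 0.5 * frame_weights[t] * err * err;
    }
  }
  return loss;
}

// Posterior overload: densify the targets, then reuse the dense version.
double WeightedMsePosterior(const std::vector<double>& frame_weights,
                            const Dense& net_out, const Posterior& post) {
  assert(net_out.size() == post.size());
  const std::size_t num_cols = net_out.empty() ? 0 : net_out[0].size();
  return WeightedMse(frame_weights, net_out, PosteriorToDense(post, num_cols));
}
```

Delegating to the dense path keeps a single implementation of the loss; the cost is materializing a full target matrix even when each frame has only a few non-zero posteriors.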
std::string Report () [virtual]
Generate a string with the error report.
Implements LossItf.
Definition at line 299 of file nnet-loss.cc.
References Mse::frames_ and Mse::loss_vec_.
Referenced by main().
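Per its References list, Report() reads at least the frame counter and the loss history (frames_ and loss_vec_). The sketch below shows how such a report string could be assembled with std::ostringstream. MakeLossReport and its output format are hypothetical, it assumes the accumulated loss is also available, and it does not reproduce the exact text Kaldi prints.

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical report builder: summarizes the frame-averaged loss and the
// per-window progress history. Not the real Mse::Report() format.
std::string MakeLossReport(double loss, double frames,
                           const std::vector<float>& loss_vec) {
  assert(frames > 0.0);  // avoid dividing by zero frames
  std::ostringstream oss;
  oss << "AvgLoss: " << loss / frames << " [frames " << frames << "]\n";
  oss << "progress: [";
  for (std::size_t i = 0; i < loss_vec.size(); ++i) {
    oss << (i ? " " : "") << loss_vec[i];  // space-separated history
  }
  oss << "]\n";
  return oss.str();
}
```

Returning a std::string rather than logging directly lets the caller (here, main()) decide where the report goes.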
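CuMatrix< BaseFloat > diff_pow_2_ [private]
Definition at line 193 of file nnet-loss.h.
CuVector< BaseFloat > frame_weights_ [private]
Definition at line 191 of file nnet-loss.h.
double frames_ [private]
Definition at line 184 of file nnet-loss.h.
double frames_progress_ [private]
Definition at line 187 of file nnet-loss.h.
double loss_ [private]
Definition at line 185 of file nnet-loss.h.
double loss_progress_ [private]
Definition at line 188 of file nnet-loss.h.
std::vector< float > loss_vec_ [private]
Definition at line 189 of file nnet-loss.h.
CuMatrix< BaseFloat > tgt_mat_ [private]
Definition at line 192 of file nnet-loss.h.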