#include <nnet-loss.h>
Public Member Functions | |
Xent (LossOptions &opts) | |
~Xent () | |
void | Eval (const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const CuMatrixBase< BaseFloat > &target, CuMatrix< BaseFloat > *diff) |
Evaluate cross entropy using target-matrix (supports soft labels). | |
void | Eval (const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const Posterior &target, CuMatrix< BaseFloat > *diff) |
Evaluate cross entropy using target-posteriors (supports soft labels). | |
std::string | Report () |
Generate string with error report. | |
std::string | ReportPerClass () |
Generate string with per-class error report. | |
BaseFloat | AvgLoss () |
Get loss value (frame average). | |
Public Member Functions inherited from LossItf | |
LossItf (LossOptions &opts) | |
virtual | ~LossItf () |
Private Attributes | |
CuVector< double > | frames_ |
Vector< double > | correct_ |
CuVector< double > | xentropy_ |
CuVector< double > | entropy_ |
double | frames_progress_ |
double | xentropy_progress_ |
double | entropy_progress_ |
std::vector< float > | loss_vec_ |
double | elapsed_seconds_ |
CuVector< BaseFloat > | frame_weights_ |
CuVector< BaseFloat > | target_sum_ |
CuMatrix< BaseFloat > | tgt_mat_ |
CuMatrix< BaseFloat > | frames_aux_ |
CuMatrix< BaseFloat > | xentropy_aux_ |
CuMatrix< BaseFloat > | entropy_aux_ |
CuArray< int32 > | max_id_out_ |
CuArray< int32 > | max_id_tgt_ |
Additional Inherited Members | |
Protected Attributes inherited from LossItf | |
LossOptions | opts_ |
Timer | timer_ |
Definition at line 82 of file nnet-loss.h.
Xent (LossOptions &opts) | inline |
~Xent () | inline |
Definition at line 92 of file nnet-loss.h.
BaseFloat AvgLoss () | inline, virtual |
Get loss value (frame average).
Implements LossItf.
Definition at line 114 of file nnet-loss.h.
Referenced by MultiTaskLoss::Report().
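void Eval (const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const CuMatrixBase< BaseFloat > &target, CuMatrix< BaseFloat > *diff) | virtual |
Evaluate cross entropy using target-matrix (supports soft labels).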
Implements LossItf.
Definition at line 62 of file nnet-loss.cc.
References CuMatrixBase< Real >::AddMat(), CuVectorBase< Real >::AddRowSumMat(), Xent::correct_, kaldi::nnet1::CountCorrectFramesWeighted(), VectorBase< Real >::Dim(), CuVectorBase< Real >::Dim(), Timer::Elapsed(), Xent::elapsed_seconds_, Xent::entropy_, Xent::entropy_aux_, Xent::entropy_progress_, CuMatrixBase< Real >::FindRowMaxId(), Xent::frame_weights_, Xent::frames_, Xent::frames_aux_, Xent::frames_progress_, KALDI_ASSERT, KALDI_ISFINITE, KALDI_LOG, LossOptions::loss_report_frames, Xent::loss_vec_, Xent::max_id_out_, Xent::max_id_tgt_, CuMatrixBase< Real >::MulRowsVec(), CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), LossItf::opts_, CuVector< Real >::Resize(), Vector< Real >::Resize(), CuVectorBase< Real >::Sum(), VectorBase< Real >::Sum(), CuMatrixBase< Real >::Sum(), Xent::target_sum_, LossItf::timer_, Xent::xentropy_, Xent::xentropy_aux_, and Xent::xentropy_progress_.
Referenced by Xent::Eval(), Mse::Eval(), and main().
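As a rough illustration of what this overload accumulates for one mini-batch, here is a minimal CPU sketch. It is written under assumptions rather than copied from nnet-loss.cc: rows of `net_out` are taken to be softmax outputs, the gradient is taken to be `net_out - target` with each row scaled by its frame weight, and a small constant guards against `log(0)`. The names `EvalXent`, `BatchResult`, `Matrix`, `ArgMax` and `kEps` are hypothetical.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;  // rows = frames, cols = classes

struct BatchResult {
  double frames = 0.0;    // sum of frame weights
  double xentropy = 0.0;  // -sum_t w_t sum_k target * log(net_out)
  double entropy = 0.0;   // -sum_t w_t sum_k target * log(target)
  double correct = 0.0;   // weighted count of frames whose argmax matches
  Matrix diff;            // (net_out - target), each row scaled by its weight
};

// Index of the largest element in a row (first one wins on ties).
inline std::size_t ArgMax(const std::vector<double> &row) {
  std::size_t best = 0;
  for (std::size_t k = 1; k < row.size(); ++k)
    if (row[k] > row[best]) best = k;
  return best;
}

// Assumes frame_weights, net_out and target agree in their dimensions.
inline BatchResult EvalXent(const std::vector<double> &frame_weights,
                            const Matrix &net_out, const Matrix &target) {
  const double kEps = 1e-20;  // assumed guard against log(0)
  BatchResult r;
  r.diff.resize(net_out.size());
  for (std::size_t t = 0; t < net_out.size(); ++t) {
    const double w = frame_weights[t];
    r.frames += w;
    r.diff[t].resize(net_out[t].size());
    for (std::size_t k = 0; k < net_out[t].size(); ++k) {
      r.diff[t][k] = w * (net_out[t][k] - target[t][k]);
      r.xentropy -= w * target[t][k] * std::log(net_out[t][k] + kEps);
      r.entropy -= w * target[t][k] * std::log(target[t][k] + kEps);
    }
    if (ArgMax(net_out[t]) == ArgMax(target[t])) r.correct += w;
  }
  return r;
}
```

A frame with weight zero contributes nothing to the statistics and gets an all-zero gradient row, which is how frames can be excluded from training in this sketch.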
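void Eval (const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const Posterior &target, CuMatrix< BaseFloat > *diff) | virtual |
Evaluate cross entropy using target-posteriors (supports soft labels).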
Implements LossItf.
Definition at line 166 of file nnet-loss.cc.
References Xent::Eval(), KALDI_ASSERT, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), kaldi::nnet1::PosteriorToMatrix(), and Xent::tgt_mat_.
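The references above indicate that this overload expands the sparse `Posterior` into the dense `tgt_mat_` with `PosteriorToMatrix()` and then forwards to the target-matrix overload. The expansion step can be sketched with standard containers as follows. Kaldi's `Posterior` is a vector over frames of (index, weight) pairs; everything else here is an assumption of the sketch, including the hypothetical names `SparsePosterior`, `DenseMatrix` and `PosteriorToDense`, the choice to sum weights for repeated indices, and the use of an exception for out-of-range indices.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>

// Same shape as Kaldi's Posterior: per frame, a list of (class index, weight).
using SparsePosterior =
    std::vector<std::vector<std::pair<std::int32_t, float>>>;
using DenseMatrix = std::vector<std::vector<float>>;

// Expand sparse targets into a num_frames x num_classes matrix that is zero
// everywhere except at the listed (frame, class) positions.
// Throws if a class index is outside [0, num_classes).
inline DenseMatrix PosteriorToDense(const SparsePosterior &post,
                                    std::int32_t num_classes) {
  DenseMatrix mat(post.size(), std::vector<float>(num_classes, 0.0f));
  for (std::size_t t = 0; t < post.size(); ++t) {
    for (const auto &entry : post[t]) {
      if (entry.first < 0 || entry.first >= num_classes)
        throw std::out_of_range("posterior index outside network output");
      mat[t][entry.first] += entry.second;  // assumption: repeats accumulate
    }
  }
  return mat;
}
```

A frame with a single entry of weight 1.0 becomes a one-hot row, while several entries per frame give the soft labels mentioned in the description.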
std::string Report () | virtual |
Generate string with error report.
Implements LossItf.
Definition at line 182 of file nnet-loss.cc.
References Xent::correct_, Xent::entropy_, Xent::frames_, Xent::loss_vec_, CuVectorBase< Real >::Sum(), VectorBase< Real >::Sum(), and Xent::xentropy_.
Referenced by main().
std::string ReportPerClass () | |
Generate string with per-class error report.
Definition at line 203 of file nnet-loss.cc.
References CuVectorBase< Real >::Add(), CuVectorBase< Real >::AddVec(), CuVectorBase< Real >::ApplyPow(), Xent::correct_, Xent::entropy_, Xent::frames_, CuVectorBase< Real >::MulElements(), CuVectorBase< Real >::Scale(), and Xent::xentropy_.
Referenced by main().
Vector< double > correct_ | private |
Definition at line 122 of file nnet-loss.h.
Referenced by Xent::Eval(), Xent::Report(), and Xent::ReportPerClass().
double elapsed_seconds_ | private |
Definition at line 131 of file nnet-loss.h.
Referenced by Xent::Eval().
CuVector< double > entropy_ | private |
Definition at line 124 of file nnet-loss.h.
Referenced by Xent::Eval(), Xent::Report(), and Xent::ReportPerClass().
CuMatrix< BaseFloat > entropy_aux_ | private |
Definition at line 141 of file nnet-loss.h.
Referenced by Xent::Eval().
double entropy_progress_ | private |
Definition at line 129 of file nnet-loss.h.
Referenced by Xent::Eval().
CuVector< BaseFloat > frame_weights_ | private |
Definition at line 134 of file nnet-loss.h.
Referenced by Xent::Eval() and Mse::Eval().
CuVector< double > frames_ | private |
Definition at line 121 of file nnet-loss.h.
Referenced by Xent::Eval(), Mse::Eval(), Xent::Report(), Mse::Report(), and Xent::ReportPerClass().
CuMatrix< BaseFloat > frames_aux_ | private |
Definition at line 139 of file nnet-loss.h.
Referenced by Xent::Eval().
double frames_progress_ | private |
Definition at line 127 of file nnet-loss.h.
Referenced by Xent::Eval() and Mse::Eval().
std::vector< float > loss_vec_ | private |
Definition at line 130 of file nnet-loss.h.
Referenced by MultiTaskLoss::AvgLoss(), Xent::Eval(), Mse::Eval(), MultiTaskLoss::Eval(), MultiTaskLoss::InitFromString(), Xent::Report(), Mse::Report(), and MultiTaskLoss::Report().
CuArray< int32 > max_id_out_ | private |
Definition at line 144 of file nnet-loss.h.
Referenced by Xent::Eval().
CuArray< int32 > max_id_tgt_ | private |
Definition at line 145 of file nnet-loss.h.
Referenced by Xent::Eval().
CuVector< BaseFloat > target_sum_ | private |
Definition at line 135 of file nnet-loss.h.
Referenced by Xent::Eval().
CuMatrix< BaseFloat > tgt_mat_ | private |
Definition at line 138 of file nnet-loss.h.
Referenced by Xent::Eval(), Mse::Eval(), and MultiTaskLoss::Eval().
CuVector< double > xentropy_ | private |
Definition at line 123 of file nnet-loss.h.
Referenced by Xent::Eval(), Xent::Report(), and Xent::ReportPerClass().
CuMatrix< BaseFloat > xentropy_aux_ | private |
Definition at line 140 of file nnet-loss.h.
Referenced by Xent::Eval().
double xentropy_progress_ | private |
Definition at line 128 of file nnet-loss.h.
Referenced by Xent::Eval().