#include <logistic-regression.h>
Public Member Functions | |
void | Train (const Matrix< BaseFloat > &xs, const std::vector< int32 > &ys, const LogisticRegressionConfig &conf) |
void | GetLogPosteriors (const Matrix< BaseFloat > &xs, Matrix< BaseFloat > *log_posteriors) |
void | GetLogPosteriors (const Vector< BaseFloat > &x, Vector< BaseFloat > *log_posteriors) |
void | Write (std::ostream &os, bool binary) const |
void | Read (std::istream &is, bool binary) |
void | ScalePriors (const Vector< BaseFloat > &prior_scales) |
Protected Member Functions | |
void friend | UnitTestTrain () |
void friend | UnitTestPosteriors () |
Private Member Functions | |
BaseFloat | DoStep (const Matrix< BaseFloat > &xs, Matrix< BaseFloat > *xw, const std::vector< int32 > &ys, OptimizeLbfgs< BaseFloat > *lbfgs, BaseFloat normalizer) |
void | TrainParameters (const Matrix< BaseFloat > &xs, const std::vector< int32 > &ys, const LogisticRegressionConfig &conf, Matrix< BaseFloat > *xw) |
void | MixUp (const std::vector< int32 > &ys, const int32 &num_classes, const LogisticRegressionConfig &conf) |
BaseFloat | GetObjfAndGrad (const Matrix< BaseFloat > &xs, const std::vector< int32 > &ys, const Matrix< BaseFloat > &xw, Matrix< BaseFloat > *grad, BaseFloat normalizer) |
void | SetWeights (const Matrix< BaseFloat > &weights, const std::vector< int32 > classes) |
Private Attributes | |
Matrix< BaseFloat > | weights_ |
std::vector< int32 > | class_ |
Definition at line 52 of file logistic-regression.h.
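BaseFloat DoStep (const Matrix< BaseFloat > &xs, Matrix< BaseFloat > *xw, const std::vector< int32 > &ys, OptimizeLbfgs< BaseFloat > *lbfgs, BaseFloat normalizer) [private]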
Definition at line 196 of file logistic-regression.cc.
References MatrixBase< Real >::AddMatMat(), MatrixBase< Real >::CopyRowsFromVec(), OptimizeLbfgs< Real >::DoStep(), LogisticRegression::GetObjfAndGrad(), OptimizeLbfgs< Real >::GetProposedValue(), KALDI_LOG, kaldi::kNoTrans, kaldi::kTrans, MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), and LogisticRegression::weights_.
Referenced by LogisticRegression::TrainParameters().
Definition at line 137 of file logistic-regression.cc.
References MatrixBase< Real >::AddMatMat(), LogisticRegression::class_, MatrixBase< Real >::CopyFromMat(), rnnlm::i, rnnlm::j, kaldi::kNoTrans, kaldi::kTrans, kaldi::LogAdd(), MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), Matrix< Real >::Resize(), MatrixBase< Real >::Row(), MatrixBase< Real >::Set(), and LogisticRegression::weights_.
Referenced by ComputeLogPosteriors(), ComputeScores(), and kaldi::UnitTestPosteriors().
Definition at line 171 of file logistic-regression.cc.
References VectorBase< Real >::Add(), VectorBase< Real >::AddMatVec(), LogisticRegression::class_, VectorBase< Real >::CopyFromVec(), VectorBase< Real >::Dim(), rnnlm::i, rnnlm::j, kaldi::kNoTrans, kaldi::LogAdd(), VectorBase< Real >::LogSumExp(), MatrixBase< Real >::NumRows(), Vector< Real >::Resize(), VectorBase< Real >::Set(), and LogisticRegression::weights_.
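BaseFloat GetObjfAndGrad (const Matrix< BaseFloat > &xs, const std::vector< int32 > &ys, const Matrix< BaseFloat > &xw, Matrix< BaseFloat > *grad, BaseFloat normalizer) [private]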
Definition at line 225 of file logistic-regression.cc.
References MatrixBase< Real >::AddMat(), LogisticRegression::class_, VectorBase< Real >::CopyFromVec(), rnnlm::i, rnnlm::j, KALDI_VLOG, kaldi::kTrans, kaldi::Log(), MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), MatrixBase< Real >::Row(), MatrixBase< Real >::Scale(), kaldi::TraceMatMat(), and LogisticRegression::weights_.
Referenced by LogisticRegression::DoStep(), and kaldi::UnitTestTrain().
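void MixUp (const std::vector< int32 > &ys, const int32 &num_classes, const LogisticRegressionConfig &conf) [private]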
Definition at line 71 of file logistic-regression.cc.
References LogisticRegression::class_, MatrixBase< Real >::CopyFromMat(), kaldi::GetSplitTargets(), rnnlm::i, rnnlm::j, KALDI_LOG, LogisticRegressionConfig::mix_up, MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), LogisticRegressionConfig::power, Matrix< Real >::Resize(), MatrixBase< Real >::Row(), VectorBase< Real >::SetRandn(), and LogisticRegression::weights_.
Referenced by LogisticRegression::Train().
void Read (std::istream &is, bool binary)
Definition at line 301 of file logistic-regression.cc.
References LogisticRegression::class_, kaldi::ExpectToken(), rnnlm::i, MatrixBase< Real >::NumRows(), Matrix< Real >::Read(), kaldi::ReadIntegerVector(), kaldi::ReadToken(), and LogisticRegression::weights_.
Definition at line 284 of file logistic-regression.cc.
References VectorBase< Real >::ApplyLog(), LogisticRegression::class_, rnnlm::i, MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), and LogisticRegression::weights_.
Referenced by main().
Definition at line 275 of file logistic-regression.cc.
References LogisticRegression::class_, MatrixBase< Real >::CopyFromMat(), rnnlm::i, MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), Matrix< Real >::Resize(), and LogisticRegression::weights_.
Referenced by kaldi::UnitTestPosteriors().
void Train (const Matrix< BaseFloat > &xs, const std::vector< int32 > &ys, const LogisticRegressionConfig &conf)
Definition at line 27 of file logistic-regression.cc.
References LogisticRegression::class_, MatrixBase< Real >::CopyFromMat(), rnnlm::i, KALDI_ASSERT, KALDI_LOG, LogisticRegressionConfig::mix_up, LogisticRegression::MixUp(), MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), Matrix< Real >::Resize(), MatrixBase< Real >::SetZero(), LogisticRegression::TrainParameters(), and LogisticRegression::weights_.
Referenced by main(), and kaldi::UnitTestTrain().
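void TrainParameters (const Matrix< BaseFloat > &xs, const std::vector< int32 > &ys, const LogisticRegressionConfig &conf, Matrix< BaseFloat > *xw) [private]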
Definition at line 117 of file logistic-regression.cc.
References VectorBase< Real >::CopyRowsFromMat(), MatrixBase< Real >::CopyRowsFromVec(), LogisticRegression::DoStep(), OptimizeLbfgs< Real >::GetValue(), LogisticRegressionConfig::max_steps, LbfgsOptions::minimize, LogisticRegressionConfig::normalizer, MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), and LogisticRegression::weights_.
Referenced by LogisticRegression::Train().
void friend UnitTestPosteriors () [protected]
void friend UnitTestTrain () [protected]
void Write (std::ostream &os, bool binary) const
Definition at line 292 of file logistic-regression.cc.
References LogisticRegression::class_, LogisticRegression::weights_, MatrixBase< Real >::Write(), kaldi::WriteIntegerVector(), and kaldi::WriteToken().
std::vector< int32 > class_ [private]
Definition at line 121 of file logistic-regression.h.
Referenced by LogisticRegression::GetLogPosteriors(), LogisticRegression::GetObjfAndGrad(), LogisticRegression::MixUp(), LogisticRegression::Read(), LogisticRegression::ScalePriors(), LogisticRegression::SetWeights(), LogisticRegression::Train(), and LogisticRegression::Write().
Definition at line 117 of file logistic-regression.h.
Referenced by LogisticRegression::DoStep(), LogisticRegression::GetLogPosteriors(), LogisticRegression::GetObjfAndGrad(), LogisticRegression::MixUp(), LogisticRegression::Read(), LogisticRegression::ScalePriors(), LogisticRegression::SetWeights(), LogisticRegression::Train(), LogisticRegression::TrainParameters(), kaldi::UnitTestTrain(), and LogisticRegression::Write().