Public Member Functions
NnetRescaler (const NnetRescaleConfig &config, const std::vector< NnetExample > &examples, Nnet *nnet)
void | Rescale ()
Private Member Functions
void | FormatInput (const std::vector< NnetExample > &data, CuMatrix< BaseFloat > *input)
  Takes the input and formats it as a single matrix, in forward_data_[0].
void | RescaleComponent (int32 c, int32 num_chunks, CuMatrixBase< BaseFloat > *cur_data_in, CuMatrix< BaseFloat > *next_data)
void | ComputeRelevantIndexes ()
BaseFloat | GetTargetAvgDeriv (int32 c)
Private Attributes
const NnetRescaleConfig & | config_
const std::vector< NnetExample > & | examples_
Nnet * | nnet_
std::vector< ChunkInfo > | chunk_info_out_
std::set< int32 > | relevant_indexes_
Detailed Description
Definition at line 26 of file rescale-nnet.cc.
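Constructor & Destructor Documentation
NnetRescaler (const NnetRescaleConfig &config, const std::vector< NnetExample > &examples, Nnet *nnet) [inline]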
Definition at line 28 of file rescale-nnet.cc.
References NnetRescaler::ComputeRelevantIndexes(), NnetRescaler::FormatInput(), NnetRescaler::GetTargetAvgDeriv(), NnetRescaler::Rescale(), and NnetRescaler::RescaleComponent().
Member Function Documentation
void ComputeRelevantIndexes () [private]
Definition at line 89 of file rescale-nnet.cc.
References Nnet::GetComponent(), NnetRescaler::nnet_, Nnet::NumComponents(), and NnetRescaler::relevant_indexes_.
Referenced by NnetRescaler::NnetRescaler(), and NnetRescaler::Rescale().
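ComputeRelevantIndexes() walks the network with Nnet::NumComponents() and Nnet::GetComponent() and fills relevant_indexes_. This page does not state the selection rule. The sketch below assumes, purely for illustration, that an index c is relevant when component c is affine and component c + 1 is a nonlinearity. ComponentKind and ComputeRelevantIndexesSketch are hypothetical names, not Kaldi API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <set>
#include <vector>

// Hypothetical stand-in for the component types; not a Kaldi enum.
enum class ComponentKind { kAffine, kNonlinear, kOther };

// Hypothetical sketch of ComputeRelevantIndexes(): returns every index c whose
// component is affine and whose successor c + 1 is a nonlinearity.
// This selection rule is an assumption, not taken from rescale-nnet.cc.
std::set<int32_t> ComputeRelevantIndexesSketch(
    const std::vector<ComponentKind> &components) {
  std::set<int32_t> relevant;
  for (std::size_t c = 0; c + 1 < components.size(); ++c) {
    if (components[c] == ComponentKind::kAffine &&
        components[c + 1] == ComponentKind::kNonlinear) {
      relevant.insert(static_cast<int32_t>(c));
    }
  }
  return relevant;
}
```

Storing the indexes in an ordered std::set, as relevant_indexes_ does, means the first and last relevant layers are available as *begin() and *rbegin().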
void FormatInput (const std::vector< NnetExample > &data, CuMatrix< BaseFloat > *input) [private]
Takes the input and formats it as a single matrix, in forward_data_[0].
Definition at line 56 of file rescale-nnet.cc.
References NnetRescaler::chunk_info_out_, Nnet::ComputeChunkInfo(), CuMatrixBase< Real >::CopyFromMat(), CuMatrixBase< Real >::CopyRowsFromVec(), Nnet::InputDim(), KALDI_ASSERT, Nnet::LeftContext(), NnetRescaler::nnet_, CuMatrix< Real >::Resize(), and Nnet::RightContext().
Referenced by NnetRescaler::NnetRescaler(), and NnetRescaler::Rescale().
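The calls listed for FormatInput() (Nnet::LeftContext(), Nnet::RightContext(), Nnet::InputDim(), CuMatrix::Resize(), CopyFromMat(), CopyRowsFromVec()) suggest that each example contributes a block of spliced frames, with any per-speaker vector copied into every row. The standalone sketch below assumes exactly that layout: num_splice = left + 1 + right rows per example, columns [features | speaker info], and extra left context in an example skipped. MatrixSketch, ExampleSketch and FormatInputSketch are hypothetical types and functions, not Kaldi's.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Minimal row-major matrix, used only in this sketch.
struct MatrixSketch {
  std::size_t rows = 0, cols = 0;
  std::vector<float> data;
  float &at(std::size_t r, std::size_t c) { return data[r * cols + c]; }
  float at(std::size_t r, std::size_t c) const { return data[r * cols + c]; }
};

// Hypothetical stand-in for NnetExample: a window of input frames whose
// first `left_context` rows precede the labeled frame, plus an optional
// per-speaker vector (empty if unused).
struct ExampleSketch {
  MatrixSketch input_frames;
  int left_context = 0;
  std::vector<float> spk_info;
};

// Stacks num_splice = nnet_left + 1 + nnet_right rows per example into one
// matrix with columns [features | speaker info]. The speaker vector is copied
// into every row of its example's block.
MatrixSketch FormatInputSketch(const std::vector<ExampleSketch> &data,
                               int nnet_left, int nnet_right) {
  assert(!data.empty() && nnet_left >= 0 && nnet_right >= 0);
  const std::size_t num_splice =
      static_cast<std::size_t>(nnet_left + 1 + nnet_right);
  const std::size_t feat_dim = data[0].input_frames.cols;
  const std::size_t spk_dim = data[0].spk_info.size();
  MatrixSketch out;
  out.rows = num_splice * data.size();
  out.cols = feat_dim + spk_dim;
  out.data.assign(out.rows * out.cols, 0.0f);
  for (std::size_t chunk = 0; chunk < data.size(); ++chunk) {
    const ExampleSketch &ex = data[chunk];
    // Skip any extra left context the example has beyond what the network uses.
    assert(ex.left_context >= nnet_left);
    const std::size_t offset =
        static_cast<std::size_t>(ex.left_context - nnet_left);
    assert(offset + num_splice <= ex.input_frames.rows);
    for (std::size_t r = 0; r < num_splice; ++r) {
      const std::size_t dest_row = chunk * num_splice + r;
      for (std::size_t c = 0; c < feat_dim; ++c)
        out.at(dest_row, c) = ex.input_frames.at(offset + r, c);
      for (std::size_t c = 0; c < spk_dim; ++c)
        out.at(dest_row, feat_dim + c) = ex.spk_info[c];
    }
  }
  return out;
}
```

For example, with nnet_left = nnet_right = 1 and an example that carries left_context = 2, the first frame of its window is skipped and the next three frames are copied.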
BaseFloat GetTargetAvgDeriv (int32 c) [private]
Definition at line 98 of file rescale-nnet.cc.
References NnetRescaler::config_, Nnet::GetComponent(), KALDI_ASSERT, KALDI_ERR, NnetRescaler::nnet_, NnetRescaler::relevant_indexes_, NnetRescaleConfig::target_avg_deriv, NnetRescaleConfig::target_first_layer_avg_deriv, and NnetRescaleConfig::target_last_layer_avg_deriv.
Referenced by NnetRescaler::NnetRescaler(), and NnetRescaler::RescaleComponent().
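GetTargetAvgDeriv() reads three config fields (target_avg_deriv, target_first_layer_avg_deriv, target_last_layer_avg_deriv) together with relevant_indexes_. One plausible selection rule, assumed here and not confirmed by this page, is that the first relevant index gets the first-layer target, the last one gets the last-layer target, and every other index gets the general target. RescaleTargetsSketch and GetTargetAvgDerivSketch are hypothetical, and the 0.3 / 0.2 / 0.1 values in the usage are illustrative only, not claimed Kaldi defaults.

```cpp
#include <cassert>
#include <cstdint>
#include <set>

// Hypothetical mirror of the three NnetRescaleConfig target fields.
struct RescaleTargetsSketch {
  float target_avg_deriv;
  float target_first_layer_avg_deriv;
  float target_last_layer_avg_deriv;
};

// Assumed rule: first relevant index -> first-layer target, last relevant
// index -> last-layer target, anything in between -> general target.
float GetTargetAvgDerivSketch(const RescaleTargetsSketch &cfg,
                              const std::set<int32_t> &relevant, int32_t c) {
  assert(relevant.count(c) == 1);  // only relevant components have a target
  if (c == *relevant.begin()) return cfg.target_first_layer_avg_deriv;
  if (c == *relevant.rbegin()) return cfg.target_last_layer_avg_deriv;
  return cfg.target_avg_deriv;
}
```

Because std::set keeps its keys sorted, comparing against begin() and rbegin() identifies the outermost relevant layers without any extra bookkeeping.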
void Rescale ()
Definition at line 200 of file rescale-nnet.cc.
References NnetRescaler::chunk_info_out_, NnetRescaler::ComputeRelevantIndexes(), NnetRescaler::examples_, NnetRescaler::FormatInput(), Nnet::GetComponent(), NnetRescaler::nnet_, Nnet::NumComponents(), Component::Propagate(), NnetRescaler::relevant_indexes_, NnetRescaler::RescaleComponent(), and CuMatrix< Real >::Swap().
Referenced by NnetRescaler::NnetRescaler(), and kaldi::nnet2::RescaleNnet().
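Rescale() references FormatInput(), ComputeRelevantIndexes(), Component::Propagate(), RescaleComponent() and CuMatrix::Swap(), which points to a layer-by-layer pass over the formatted examples. The sketch below shows one assumed version of that control flow: a relevant layer is adjusted from its current input before being propagated, so later layers see outputs of already-rescaled earlier layers, and two buffers are swapped between steps. The exact split of work between RescaleComponent() and Propagate() is not shown on this page; LayerSketch and RescaleSketch are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <set>
#include <utility>
#include <vector>

using Activations = std::vector<double>;

// Hypothetical layer: a scale that rescaling may adjust, plus a propagate
// function that receives the current scale and the layer input.
struct LayerSketch {
  double scale = 1.0;
  std::function<Activations(double, const Activations &)> propagate;
};

// Hypothetical sketch of the Rescale() driver loop. Each relevant layer is
// handed to rescale_hook together with its input before it is propagated,
// so later layers see the outputs of already-rescaled earlier layers.
Activations RescaleSketch(
    std::vector<LayerSketch> &layers, const std::set<std::size_t> &relevant,
    Activations input,
    const std::function<void(LayerSketch &, const Activations &)> &rescale_hook) {
  Activations cur = std::move(input), next;
  for (std::size_t c = 0; c < layers.size(); ++c) {
    if (relevant.count(c) == 1) rescale_hook(layers[c], cur);
    assert(layers[c].propagate);  // every layer needs a propagate function
    next = layers[c].propagate(layers[c].scale, cur);
    std::swap(cur, next);  // reuse buffers, much as CuMatrix::Swap() would
  }
  return cur;
}
```

Processing layers in order matters: rescaling layer c changes the input distribution of every layer after it, so each later target has to be measured on the updated activations.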
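void RescaleComponent (int32 c, int32 num_chunks, CuMatrixBase< BaseFloat > *cur_data_in, CuMatrix< BaseFloat > *next_data) [private]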
Definition at line 121 of file rescale-nnet.cc.
References Component::Backprop(), NnetRescaler::chunk_info_out_, NnetRescaler::config_, NnetRescaleConfig::delta, Nnet::GetComponent(), NnetRescaler::GetTargetAvgDeriv(), KALDI_ASSERT, KALDI_ERR, KALDI_LOG, KALDI_VLOG, NnetRescaleConfig::max_change, NnetRescaleConfig::min_change, NnetRescaler::nnet_, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), Component::Propagate(), UpdatableComponent::Scale(), and CuMatrixBase< Real >::Sum().
Referenced by NnetRescaler::NnetRescaler(), and NnetRescaler::Rescale().
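RescaleComponent() references Propagate(), Backprop(), UpdatableComponent::Scale(), GetTargetAvgDeriv() and the config fields delta, min_change and max_change, but this page does not describe the algorithm. The sketch below is one plausible reading, not Kaldi's implementation. It assumes a tanh nonlinearity and searches for a scale on the pre-activations whose average derivative matches a target. It treats delta as a relative perturbation for a finite-difference slope, max_change as a cap on the relative change per iteration, and min_change as the termination threshold. AvgTanhDeriv and FindScaleSketch are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Average derivative of tanh over the pre-activations x after scaling by s.
double AvgTanhDeriv(const std::vector<double> &x, double s) {
  assert(!x.empty());
  double sum = 0.0;
  for (double v : x) {
    const double t = std::tanh(s * v);
    sum += 1.0 - t * t;  // d/dz tanh(z) = 1 - tanh(z)^2
  }
  return sum / static_cast<double>(x.size());
}

// Hypothetical scale search: finite-difference Newton steps on the scale,
// each relative change clamped to [-max_change, max_change], stopping once
// the relative change drops below min_change.
double FindScaleSketch(const std::vector<double> &x, double target,
                       double delta, double max_change, double min_change,
                       int max_iters) {
  double s = 1.0;
  for (int iter = 0; iter < max_iters; ++iter) {
    const double f = AvgTanhDeriv(x, s);
    // Slope of the average derivative with respect to s, estimated by
    // perturbing s by the relative amount delta.
    const double slope = (AvgTanhDeriv(x, s * (1.0 + delta)) - f) / (s * delta);
    if (slope == 0.0) break;  // flat region: no useful Newton direction
    double rel_change = (target - f) / slope / s;
    rel_change = std::max(-max_change, std::min(max_change, rel_change));
    s *= 1.0 + rel_change;
    if (std::fabs(rel_change) < min_change) break;
  }
  return s;
}
```

Clamping the step in relative terms keeps the scale positive and prevents a poor finite-difference slope from producing a large jump; in a real network the scale would then be applied to the component with something like UpdatableComponent::Scale().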
Member Data Documentation
std::vector< ChunkInfo > chunk_info_out_ [private]
Definition at line 50 of file rescale-nnet.cc.
Referenced by NnetRescaler::FormatInput(), NnetRescaler::Rescale(), and NnetRescaler::RescaleComponent().
const NnetRescaleConfig & config_ [private]
Definition at line 47 of file rescale-nnet.cc.
Referenced by NnetRescaler::GetTargetAvgDeriv(), and NnetRescaler::RescaleComponent().
const std::vector< NnetExample > & examples_ [private]
Definition at line 48 of file rescale-nnet.cc.
Referenced by NnetRescaler::Rescale().
Nnet * nnet_ [private]
Definition at line 49 of file rescale-nnet.cc.
Referenced by NnetRescaler::ComputeRelevantIndexes(), NnetRescaler::FormatInput(), NnetRescaler::GetTargetAvgDeriv(), NnetRescaler::Rescale(), and NnetRescaler::RescaleComponent().
std::set< int32 > relevant_indexes_ [private]
Definition at line 51 of file rescale-nnet.cc.
Referenced by NnetRescaler::ComputeRelevantIndexes(), NnetRescaler::GetTargetAvgDeriv(), and NnetRescaler::Rescale().