#include <nnet-chain-example.h>
Public Member Functions
NnetChainSupervision ()
NnetChainSupervision (const std::string &name, const chain::Supervision &supervision, const VectorBase< BaseFloat > &deriv_weights, int32 first_frame, int32 frame_skip)
Initialize the object from an object of type chain::Supervision, and some extra information.
NnetChainSupervision (const NnetChainSupervision &other)
void Write (std::ostream &os, bool binary) const
void Read (std::istream &is, bool binary)
void Swap (NnetChainSupervision *other)
void CheckDim () const
bool operator== (const NnetChainSupervision &other) const
Public Attributes
std::string name
The name of the output in the neural net; in simple setups it will just be "output".
std::vector< Index > indexes
The indexes that the output corresponds to.
chain::Supervision supervision
The supervision object, containing the FST.
Vector< BaseFloat > deriv_weights
This is a vector of per-frame weights, required to be between 0 and 1, that is applied to the derivative during training (but not during model combination, where the derivatives need to agree with the computed objf values for the optimization code to work).
Definition at line 43 of file nnet-chain-example.h.
NnetChainSupervision () [inline]
Definition at line 81 of file nnet-chain-example.h.
References NnetChainSupervision::CheckDim(), NnetChainSupervision::operator==(), NnetChainSupervision::Read(), NnetChainSupervision::Swap(), and NnetChainSupervision::Write().
NnetChainSupervision (const std::string &name, const chain::Supervision &supervision, const VectorBase< BaseFloat > &deriv_weights, int32 first_frame, int32 frame_skip)
Initialize the object from an object of type chain::Supervision, and some extra information.
Note: you probably want to set 'name' to "output". 'first_frame' will often be zero, but you can choose another value (just make it consistent with how you numbered your inputs), and 'frame_skip' would be 1 in a vanilla setup. We plan to try setups where the output periodicity is slower than the input, in which case it might be 2 or 3.
Definition at line 107 of file nnet-chain-example.cc.
References NnetChainSupervision::CheckDim(), rnnlm::i, NnetChainSupervision::indexes, rnnlm::j, and KALDI_ASSERT.
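The frame-major layout that this constructor produces for 'indexes' can be sketched without Kaldi as below. This is a hypothetical illustration, not the library's implementation: `Index` is a simplified stand-in for `kaldi::nnet3::Index` (n = sequence, t = time, x unused), `MakeChainIndexes` is an invented helper, and mapping output frame i to time first_frame + i * frame_skip is an assumption based on the description of 'first_frame' and 'frame_skip' above.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for kaldi::nnet3::Index (n = sequence, t = time, x = extra).
struct Index {
  int32_t n = 0, t = 0, x = 0;
  Index() = default;
  Index(int32_t n_in, int32_t t_in, int32_t x_in = 0) : n(n_in), t(t_in), x(x_in) {}
  bool operator==(const Index &o) const { return n == o.n && t == o.t && x == o.x; }
};

// Builds 'indexes' in frame-major order: all sequences at output frame 0,
// then all sequences at output frame 1, and so on. Output frame i is assumed
// to correspond to time first_frame + i * frame_skip.
std::vector<Index> MakeChainIndexes(int32_t num_sequences,
                                    int32_t frames_per_sequence,
                                    int32_t first_frame,
                                    int32_t frame_skip) {
  std::vector<Index> indexes;
  indexes.reserve(static_cast<std::size_t>(num_sequences) * frames_per_sequence);
  for (int32_t i = 0; i < frames_per_sequence; i++)  // output frame (outer loop)
    for (int32_t j = 0; j < num_sequences; j++)      // sequence (inner loop)
      indexes.emplace_back(j, first_frame + i * frame_skip, 0);
  return indexes;
}
```

For example, with 2 sequences, 3 output frames, first_frame = 0 and frame_skip = 3, the sketch yields (n=0,t=0), (n=1,t=0), (n=0,t=3), (n=1,t=3), (n=0,t=6), (n=1,t=6).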
NnetChainSupervision (const NnetChainSupervision &other)
Definition at line 92 of file nnet-chain-example.cc.
References NnetChainSupervision::CheckDim().
void CheckDim () const
Definition at line 65 of file nnet-chain-example.cc.
References NnetChainSupervision::deriv_weights, rnnlm::i, NnetChainSupervision::indexes, rnnlm::j, KALDI_ASSERT, rnnlm::n, and NnetChainSupervision::supervision.
Referenced by kaldi::nnet3::MergeSupervision(), NnetChainSupervision::NnetChainSupervision(), NnetChainSupervision::Read(), NnetChainSupervision::Swap(), and NnetChainSupervision::Write().
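This entry does not describe what CheckDim() verifies. As a hedged sketch, the function below checks only the invariants that this page documents elsewhere: the size of 'indexes', its frame-major order, and that 'deriv_weights' is either empty or holds one weight in [0, 1] per index. The type and function names are hypothetical, the sketch returns a bool rather than asserting, and the requirement that successive frames increase in time is an assumption drawn from the (frame 0; frame 1; ...) ordering.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Simplified, hypothetical stand-ins for nnet3::Index and the fields of
// NnetChainSupervision that the dimension checks care about.
struct Index { int32_t n, t, x; };

struct ChainSupervisionSketch {
  int32_t num_sequences = 0;
  int32_t frames_per_sequence = 0;
  std::vector<Index> indexes;        // frame-major order
  std::vector<float> deriv_weights;  // empty means "all ones"
};

// Returns true if 's' satisfies the invariants stated in the class
// documentation. Unlike a real CheckDim(), this returns a bool instead of
// asserting, and it checks only the documented properties.
bool SatisfiesDocumentedInvariants(const ChainSupervisionSketch &s) {
  if (s.num_sequences < 0 || s.frames_per_sequence < 0) return false;
  const std::size_t num_seq = static_cast<std::size_t>(s.num_sequences);
  const std::size_t expected =
      num_seq * static_cast<std::size_t>(s.frames_per_sequence);
  if (s.indexes.size() != expected) return false;
  for (std::size_t k = 0; k < s.indexes.size(); k++) {
    const std::size_t seq = k % num_seq, frame_start = k - seq;
    // Within one output frame, entries run over sequences 0, 1, 2, ...
    if (s.indexes[k].n != static_cast<int32_t>(seq)) return false;
    // ... and all share the time of the first entry of that frame.
    if (s.indexes[k].t != s.indexes[frame_start].t) return false;
    // Successive output frames are assumed to move forward in time.
    if (seq == 0 && k >= num_seq &&
        s.indexes[k].t <= s.indexes[k - num_seq].t)
      return false;
  }
  if (!s.deriv_weights.empty()) {
    if (s.deriv_weights.size() != s.indexes.size()) return false;
    for (float w : s.deriv_weights)
      if (w < 0.0f || w > 1.0f) return false;
  }
  return true;
}
```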
bool operator== (const NnetChainSupervision &other) const
Definition at line 39 of file nnet-chain-example.cc.
References NnetChainSupervision::deriv_weights, NnetChainSupervision::indexes, NnetChainSupervision::name, and NnetChainSupervision::supervision.
Referenced by NnetChainSupervision::NnetChainSupervision().
void Read (std::istream &is, bool binary)
Definition at line 45 of file nnet-chain-example.cc.
References NnetChainSupervision::CheckDim(), NnetChainSupervision::deriv_weights, kaldi::nnet3::ExpectToken(), NnetChainSupervision::indexes, KALDI_ASSERT, NnetChainSupervision::name, kaldi::nnet3::ReadIndexVector(), kaldi::ReadToken(), kaldi::nnet3::ReadVectorAsChar(), and NnetChainSupervision::supervision.
Referenced by NnetChainSupervision::NnetChainSupervision(), and NnetChainExample::Read().
void Swap (NnetChainSupervision *other)
Definition at line 98 of file nnet-chain-example.cc.
References NnetChainSupervision::CheckDim(), NnetChainSupervision::deriv_weights, NnetChainSupervision::indexes, NnetChainSupervision::name, kaldi::RandInt(), and NnetChainSupervision::supervision.
Referenced by NnetChainSupervision::NnetChainSupervision().
void Write (std::ostream &os, bool binary) const
Definition at line 28 of file nnet-chain-example.cc.
References NnetChainSupervision::CheckDim(), NnetChainSupervision::deriv_weights, NnetChainSupervision::indexes, NnetChainSupervision::name, NnetChainSupervision::supervision, kaldi::nnet3::WriteIndexVector(), and kaldi::WriteToken().
Referenced by NnetChainSupervision::NnetChainSupervision().
Vector<BaseFloat> deriv_weights
This is a vector of per-frame weights, required to be between 0 and 1, that is applied to the derivative during training (but not during model combination, where the derivatives need to agree with the computed objf values for the optimization code to work).
The reason for this is to more exactly handle edge effects and to ensure that no frames are 'double-counted'. The order of this vector corresponds to the order of the 'indexes' (i.e. all the first frames, then all the second frames, etc.). If this vector is empty, it means we're not applying per-frame weights, so it's equivalent to a vector of all ones. This vector is written to disk compactly as unsigned char.
Definition at line 77 of file nnet-chain-example.h.
Referenced by NnetChainSupervision::CheckDim(), kaldi::nnet3::MergeSupervision(), NnetChainSupervision::operator==(), NnetChainTrainer::ProcessOutputs(), NnetChainSupervision::Read(), NnetChainSupervision::Swap(), and NnetChainSupervision::Write().
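Two properties described for 'deriv_weights', that an empty vector means all ones and that the weights are stored compactly as unsigned char, can be sketched as below. This is an illustration only: the helper names are hypothetical, a flat row-major `std::vector<float>` stands in for Kaldi's matrix types, the sketch does not model the training versus model-combination distinction, and the 0..255 rounding scheme is an assumption about what "written compactly as unsigned char" means rather than a description of Kaldi's on-disk format.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Scales each row of a derivative matrix (row-major, one row per entry of
// 'indexes', num_cols columns) by the matching per-frame weight.
// An empty 'deriv_weights' means no weighting (equivalent to all ones).
void ApplyDerivWeights(const std::vector<float> &deriv_weights,
                       std::size_t num_cols,
                       std::vector<float> *deriv) {
  if (deriv_weights.empty()) return;  // all ones: nothing to do
  assert(deriv->size() == deriv_weights.size() * num_cols);
  for (std::size_t r = 0; r < deriv_weights.size(); r++)
    for (std::size_t c = 0; c < num_cols; c++)
      (*deriv)[r * num_cols + c] *= deriv_weights[r];
}

// Assumed compact encoding: a weight in [0, 1] is mapped to one of 256
// levels, rounding to the nearest level, so the round-trip error is at
// most 1/510.
unsigned char WeightToChar(float w) {
  assert(w >= 0.0f && w <= 1.0f);
  return static_cast<unsigned char>(255.0f * w + 0.5f);
}

float CharToWeight(unsigned char c) { return c / 255.0f; }
```

Because the weights are bounded to [0, 1], one byte per frame loses little precision while using a quarter of the space of a 32-bit float.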
std::vector<Index> indexes
The indexes that the output corresponds to.
The size of this vector will be equal to supervision.num_sequences * supervision.frames_per_sequence. Be careful about the order of these indexes; it is a little confusing. The indexes in the 'indexes' vector are ordered as: (frame 0 of each sequence); (frame 1 of each sequence); and so on. But in the 'supervision' object, the FST contains (sequence 0; sequence 1; ...). So reordering is needed when doing the numerator computation. We order 'indexes' in this way for efficiency in the denominator computation (it helps memory locality), as well as to avoid the need for the nnet to reorder things internally to match the requested output (for layers inside the neural net, the ordering is (frame 0; frame 1; ...), as this corresponds to the order you get when you sort a vector of Index).
Definition at line 60 of file nnet-chain-example.h.
Referenced by NnetChainSupervision::CheckDim(), kaldi::nnet3::GetChainComputationRequest(), kaldi::nnet3::MergeSupervision(), NnetChainSupervision::NnetChainSupervision(), NnetChainExampleStructureHasher::operator()(), NnetChainSupervision::operator==(), NnetChainSupervision::Read(), NnetChainSupervision::Swap(), and NnetChainSupervision::Write().
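The reordering between the frame-major 'indexes' layout and the sequence-major layout of the supervision FST amounts to an index transpose. The sketch below shows that mapping on plain vectors; it is not how Kaldi performs the numerator computation, and the helper names are hypothetical. It also includes a stand-in `Index` compared on (t, x, n); that comparison order is an assumption chosen so that sorting reproduces the frame-major order described above.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <tuple>
#include <vector>

// Row r of the frame-major ('indexes') layout holds frame t = r / num_sequences
// of sequence s = r % num_sequences. In a sequence-major layout (the order of
// sequences in the supervision FST), the same (s, t) is at s * frames_per_sequence + t.
std::size_t FrameMajorToSequenceMajor(std::size_t r, std::size_t num_sequences,
                                      std::size_t frames_per_sequence) {
  const std::size_t t = r / num_sequences, s = r % num_sequences;
  return s * frames_per_sequence + t;
}

// Reorders one value per row from frame-major to sequence-major order.
template <typename T>
std::vector<T> ToSequenceMajor(const std::vector<T> &frame_major,
                               std::size_t num_sequences,
                               std::size_t frames_per_sequence) {
  assert(frame_major.size() == num_sequences * frames_per_sequence);
  std::vector<T> out(frame_major.size());
  for (std::size_t r = 0; r < frame_major.size(); r++)
    out[FrameMajorToSequenceMajor(r, num_sequences, frames_per_sequence)] =
        frame_major[r];
  return out;
}

// Hypothetical stand-in for nnet3::Index, compared on (t, x, n) so that
// sorting produces frame-major order, as the documentation describes.
struct Index {
  int n, t, x;
  bool operator<(const Index &o) const {
    return std::tie(t, x, n) < std::tie(o.t, o.x, o.n);
  }
};
```

For instance, with 2 sequences of 3 frames, frame-major values {s0t0, s1t0, s0t1, s1t1, s0t2, s1t2} become {s0t0, s0t1, s0t2, s1t0, s1t1, s1t2} in sequence-major order.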
std::string name
The name of the output in the neural net; in simple setups it will just be "output".
Definition at line 46 of file nnet-chain-example.h.
Referenced by kaldi::nnet3::GetChainComputationRequest(), kaldi::nnet3::MergeSupervision(), NnetChainExampleStructureHasher::operator()(), NnetChainSupervision::operator==(), NnetChainTrainer::ProcessOutputs(), NnetChainComputeProb::ProcessOutputs(), NnetChainSupervision::Read(), NnetChainSupervision::Swap(), and NnetChainSupervision::Write().
chain::Supervision supervision
The supervision object, containing the FST.
Definition at line 64 of file nnet-chain-example.h.
Referenced by NnetChainSupervision::CheckDim(), kaldi::nnet3::MergeSupervision(), NnetChainSupervision::operator==(), NnetChainTrainer::ProcessOutputs(), NnetChainComputeProb::ProcessOutputs(), NnetChainSupervision::Read(), NnetChainSupervision::Swap(), and NnetChainSupervision::Write().