26 namespace discriminative {
29 std::memset((
void *)
this, 0,
sizeof(*
this));
150 void LookupNnetOutput(std::vector<Int32Pair> *requested_indexes,
151 std::vector<BaseFloat> *answers)
const ;
155 void ConvertAnswersToLogLike(
156 const std::vector<Int32Pair>& requested_indexes,
157 std::vector<BaseFloat> *answers)
const;
168 void ProcessPosteriors(
const Posterior &post,
170 double *tot_num_post = NULL,
171 double *tot_den_post = NULL)
const;
190 : opts_(opts), tmodel_(tmodel), log_priors_(log_priors),
191 supervision_(supervision), nnet_output_(nnet_output),
193 nnet_output_deriv_(nnet_output_deriv),
194 xent_output_deriv_(xent_output_deriv) {
201 KALDI_ERR <<
"Bad value for --silence-phones option: " 207 std::vector<Int32Pair> *requested_indexes,
208 std::vector<BaseFloat> *answers)
const {
218 num_reserve += num_frames;
221 requested_indexes->reserve(num_reserve);
224 std::vector<int32> state_times;
229 for (
StateId s = 0; s < num_states; s++) {
230 int32 t = state_times[s];
234 for (fst::ArcIterator<Lattice> aiter(
den_lat_, s); !aiter.Done(); aiter.Next()) {
235 const Arc &arc = aiter.Value();
236 if (arc.ilabel != 0) {
246 for (
int32 t = 0; t < num_frames; t++) {
257 answers->resize(requested_indexes->size());
258 nnet_output_.Lookup(cu_requested_indexes, &((*answers)[0]));
264 const std::vector<Int32Pair>& requested_indexes,
265 std::vector<BaseFloat> *answers)
const {
266 int32 num_floored = 0;
275 for (index = 0; index < answers->size(); index++) {
277 if (log_post < floor_val) {
279 log_post = floor_val;
284 int32 pdf_id = requested_indexes[index].second;
286 BaseFloat pseudo_loglike = (log_post - log_priors(pdf_id))
289 (*answers)[index] = pseudo_loglike;
295 if (num_floored > 0) {
296 KALDI_WARN <<
"Floored " << num_floored <<
" probabilities from nnet.";
301 const std::vector<BaseFloat> &answers,
303 int32 num_states = lat->NumStates();
305 for (
StateId s = 0; s < num_states; s++) {
306 for (fst::MutableArcIterator<Lattice> aiter(lat, s);
307 !aiter.Done(); aiter.Next()) {
308 Arc arc = aiter.Value();
309 if (arc.ilabel != 0) {
310 arc.weight.SetValue2(-answers[index]);
318 lat->SetFinal(s,
final);
329 double *tot_num_post,
330 double *tot_den_post)
const {
331 std::vector<Int32Pair> deriv_indexes;
332 std::vector<BaseFloat> deriv_data;
333 for (
size_t t = 0; t < post.size(); t++) {
334 for (
size_t j = 0;
j < post[t].size();
j++) {
337 int32 pdf_id = post[t][
j].first;
343 if (tot_num_post && weight > 0.0) *tot_num_post += weight;
344 if (tot_den_post && weight < 0.0) *tot_den_post -= weight;
345 deriv_data.push_back(weight);
370 std::vector<BaseFloat> answers;
371 std::vector<Int32Pair> requested_indexes;
394 double tot_num_like = 0.0;
397 tot_num_like += answers[index + this_index];
435 output_deriv_temp = &output_deriv;
438 double tot_num_post = 0.0, tot_den_post = 0.0;
441 &tot_num_post, &tot_den_post);
457 (this_stats.
output).AddRowSumMat(1.0, temp);
460 this_stats.
tot_t = num_frames;
471 <<
", setting to " << default_objf <<
" per frame.";
494 for (
int32 i = 0;
i < tot_frames;
i++)
495 row_products_per_frame(
i / num_sequences) += row_products_cpu(
i);
496 KALDI_LOG <<
"Derivs per frame are " << row_products_per_frame;
532 bool convert_to_pdfs =
true, cancel =
true;
556 nnet_output_deriv, xent_output_deriv);
561 tot_t += other.
tot_t;
569 if (AccumulateGradients()) {
572 if (AccumulateOutput()) {
573 output.AddVec(1.0, other.
output);
578 bool print_avg_gradients,
579 bool print_avg_output)
const {
580 if (criterion ==
"mmi") {
581 double num_objf = tot_num_objf / tot_t_weighted,
582 den_objf = tot_objf / tot_t_weighted;
583 double objf = num_objf - den_objf;
585 double avg_post_per_frame = tot_num_count / tot_t_weighted;
587 KALDI_LOG <<
"Number of frames is " << tot_t
588 <<
" (weighted: " << tot_t_weighted
589 <<
"), average (num or den) posterior per frame is " 590 << avg_post_per_frame;
591 KALDI_LOG <<
"MMI objective function is " << num_objf <<
" - " 592 << den_objf <<
" = " << objf <<
" per frame, over " 593 << tot_t_weighted <<
" frames.";
594 }
else if (criterion ==
"mpfe") {
595 double avg_gradients = (tot_num_count + tot_den_count) / tot_t_weighted;
596 double objf = tot_objf / tot_t_weighted;
597 KALDI_LOG <<
"Average num+den count of MPFE stats is " << avg_gradients
598 <<
" per frame, over " 599 << tot_t_weighted <<
" frames";
600 KALDI_LOG <<
"MPFE objective function is " << objf
601 <<
" per frame, over " << tot_t_weighted <<
" frames.";
602 }
else if (criterion ==
"smbr") {
603 double avg_gradients = (tot_num_count + tot_den_count) / tot_t_weighted;
604 double objf = tot_objf / tot_t_weighted;
605 KALDI_LOG <<
"Average num+den count of SMBR stats is " << avg_gradients
606 <<
" per frame, over " 607 << tot_t_weighted <<
" frames";
608 KALDI_LOG <<
"SMBR objective function is " << objf
609 <<
" per frame, over " << tot_t_weighted <<
" frames.";
612 if (AccumulateGradients()) {
614 temp.
Scale(1.0/tot_t_weighted);
615 if (print_avg_gradients) {
616 KALDI_LOG <<
"Vector of average gradients wrt output activations is: \n" << temp;
618 KALDI_VLOG(4) <<
"Vector of average gradients wrt output activations is: \n" << temp;
621 if (AccumulateOutput()) {
623 temp.
Scale(1.0/tot_t_weighted);
624 if (print_avg_output) {
625 KALDI_LOG <<
"Average DNN output is: \n" << temp;
627 KALDI_VLOG(4) <<
"Average DNN output is: \n" << temp;
633 if (pdf_id < gradients.Dim() && pdf_id >= 0) {
634 KALDI_LOG <<
"Average gradient wrt output activations of pdf " << pdf_id
635 <<
" is " << gradients(pdf_id) / tot_t_weighted
636 <<
" per frame, over " 637 << tot_t_weighted <<
" frames";
fst::StdArc::StateId StateId
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void LookupNnetOutput(std::vector< Int32Pair > *requested_indexes, std::vector< BaseFloat > *answers) const
void SetZero()
Math operations.
void ConvertAnswersToLogLike(const std::vector< Int32Pair > &requested_indexes, std::vector< BaseFloat > *answers) const
int32 LatticeStateTimes(const Lattice &lat, vector< int32 > *times)
This function iterates over the states of a topologically sorted lattice and counts the time instance...
const DiscriminativeSupervision & supervision_
void PrintAvgGradientForPdf(int32 pdf_id) const
bool SplitStringToIntegers(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< I > *out)
Split a string (e.g.
const DiscriminativeOptions & opts_
int32 GetVerboseLevel()
Get verbosity level, usually set via the command-line '--verbose=' switch.
void AddElements(Real alpha, const std::vector< MatrixElement< Real > > &input)
void PrintAll(const std::string &criterion) const
void AddDiagMat2(Real alpha, const CuMatrixBase< Real > &M, MatrixTransposeType trans, Real beta)
Add the diagonal of a matrix times itself: *this = diag(M M^T) + beta * *this (if trans == kNoTrans)...
DiscriminativeComputation(const DiscriminativeOptions &opts, const TransitionModel &tmodel, const CuVectorBase< BaseFloat > &log_priors, const DiscriminativeSupervision &supervision, const CuMatrixBase< BaseFloat > &nnet_output, DiscriminativeObjectiveInfo *stats, CuMatrixBase< BaseFloat > *nnet_output_deriv, CuMatrixBase< BaseFloat > *xent_output_deriv)
void LatticeAcousticRescore(const TransitionModel &trans_model, const Matrix< BaseFloat > &log_likes, const std::vector< int32 > &state_times, Lattice *lat)
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
double ComputeObjfAndDeriv(Posterior *post, Posterior *xent_post)
int32 TransitionIdToPdf(int32 trans_id) const
const CuVectorBase< BaseFloat > & log_priors_
static Int32Pair MakePair(int32 first, int32 second)
CuVector< double > output
void Add(const DiscriminativeObjectiveInfo &other)
double TotalObjf(const std::string &criterion) const
const TransitionModel & tmodel_
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
BaseFloat LatticeForwardBackwardMmi(const TransitionModel &tmodel, const Lattice &lat, const std::vector< int32 > &num_ali, bool drop_frames, bool convert_to_pdf_ids, bool cancel, Posterior *post)
This function can be used to compute posteriors for MMI, with a positive contribution for the numerat...
bool AccumulateGradients() const
CuVector< double > gradients
static size_t LatticeAcousticRescore(const std::vector< BaseFloat > &answers, size_t index, Lattice *lat)
std::vector< int32 > num_ali
void ComputeDiscriminativeObjfAndDeriv(const DiscriminativeOptions &opts, const TransitionModel &tmodel, const CuVectorBase< BaseFloat > &log_priors, const DiscriminativeSupervision &supervision, const CuMatrixBase< BaseFloat > &nnet_output, DiscriminativeObjectiveInfo *stats, CuMatrixBase< BaseFloat > *nnet_output_deriv, CuMatrixBase< BaseFloat > *xent_output_deriv)
This function does forward-backward on the numerator and denominator lattices and computes derivatives ...
static const LatticeWeightTpl Zero()
void AlignmentToPosterior(const std::vector< int32 > &ali, Posterior *post)
Convert an alignment to a posterior (with a scale of 1.0 on each entry).
fst::VectorFst< LatticeArc > Lattice
void Resize(MatrixIndexT dim, MatrixResizeType t=kSetZero)
Allocate the memory.
int32 frames_per_sequence
BaseFloat LatticeForwardBackwardMpeVariants(const TransitionModel &trans, const std::vector< int32 > &silence_phones, const Lattice &lat, const std::vector< int32 > &num_ali, std::string criterion, bool one_silence_class, Posterior *post)
This function implements either the MPFE (minimum phone frame error) or SMBR (state-level minimum bay...
CuMatrixBase< BaseFloat > * xent_output_deriv_
Real TraceMatMat(const MatrixBase< Real > &A, const MatrixBase< Real > &B, MatrixTransposeType trans)
We need to declare this here as it will be a friend function.
void Configure(const DiscriminativeOptions &opts)
CuMatrixBase< BaseFloat > * nnet_output_deriv_
void Scale(Real alpha)
Multiplies all elements by this constant.
DiscriminativeObjectiveInfo()
void Print(const std::string &criterion, bool print_avg_gradients=false, bool print_avg_output=false) const
bool accumulate_gradients
bool AccumulateOutput() const
const CuMatrixBase< BaseFloat > & nnet_output_
Matrix for CUDA computing.
A class representing a vector.
std::string silence_phones_str
#define KALDI_ASSERT(cond)
bool LatticeBoost(const TransitionModel &trans, const std::vector< int32 > &alignment, const std::vector< int32 > &silence_phones, BaseFloat b, BaseFloat max_silence_error, Lattice *lat)
Boosts LM probabilities by b * [number of frame errors]; equivalently, adds -b*[number of frame error...
bool accumulate_gradients
void ConvertPosteriorToPdfs(const TransitionModel &tmodel, const Posterior &post_in, Posterior *post_out)
Converts a posterior over transition-ids to be a posterior over pdf-ids.
std::vector< int32 > silence_phones_
DiscriminativeObjectiveInfo * stats_
Vector for CUDA computing.
void ProcessPosteriors(const Posterior &post, CuMatrixBase< BaseFloat > *output_deriv_temp, double *tot_num_post=NULL, double *tot_den_post=NULL) const