All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
DiscriminativeComputation Class Reference
Collaboration diagram for DiscriminativeComputation:

Public Member Functions

 DiscriminativeComputation (const DiscriminativeOptions &opts, const TransitionModel &tmodel, const CuVectorBase< BaseFloat > &log_priors, const DiscriminativeSupervision &supervision, const CuMatrixBase< BaseFloat > &nnet_output, DiscriminativeObjectiveInfo *stats, CuMatrixBase< BaseFloat > *nnet_output_deriv, CuMatrixBase< BaseFloat > *xent_output_deriv)
 
void Compute ()
 

Private Types

typedef Lattice::Arc Arc
 
typedef Arc::StateId StateId
 

Private Member Functions

double ComputeObjfAndDeriv (Posterior *post, Posterior *xent_post)
 
void LookupNnetOutput (std::vector< Int32Pair > *requested_indexes, std::vector< BaseFloat > *answers) const
 
void ConvertAnswersToLogLike (const std::vector< Int32Pair > &requested_indexes, std::vector< BaseFloat > *answers) const
 
void ProcessPosteriors (const Posterior &post, CuMatrixBase< BaseFloat > *output_deriv_temp, double *tot_num_post=NULL, double *tot_den_post=NULL) const
 

Static Private Member Functions

static size_t LatticeAcousticRescore (const std::vector< BaseFloat > &answers, size_t index, Lattice *lat)
 
static Int32Pair MakePair (int32 first, int32 second)
 

Private Attributes

const DiscriminativeOptions & opts_
 
const TransitionModel & tmodel_
 
const CuVectorBase< BaseFloat > & log_priors_
 
const DiscriminativeSupervision & supervision_
 
const CuMatrixBase< BaseFloat > & nnet_output_
 
DiscriminativeObjectiveInfo * stats_
 
CuMatrixBase< BaseFloat > * nnet_output_deriv_
 
CuMatrixBase< BaseFloat > * xent_output_deriv_
 
Lattice den_lat_
 
std::vector< int32 > silence_phones_
 

Detailed Description

Definition at line 80 of file discriminative-training.cc.

Member Typedef Documentation

typedef Lattice::Arc Arc
private

Definition at line 81 of file discriminative-training.cc.

typedef Arc::StateId StateId
private

Definition at line 82 of file discriminative-training.cc.

Constructor & Destructor Documentation

DiscriminativeComputation ( const DiscriminativeOptions &  opts,
const TransitionModel &  tmodel,
const CuVectorBase< BaseFloat > &  log_priors,
const DiscriminativeSupervision &  supervision,
const CuMatrixBase< BaseFloat > &  nnet_output,
DiscriminativeObjectiveInfo *  stats,
CuMatrixBase< BaseFloat > *  nnet_output_deriv,
CuMatrixBase< BaseFloat > *  xent_output_deriv 
)

Definition at line 181 of file discriminative-training.cc.

References DiscriminativeSupervision::den_lat, DiscriminativeComputation::den_lat_, KALDI_ERR, DiscriminativeComputation::opts_, DiscriminativeComputation::silence_phones_, DiscriminativeOptions::silence_phones_str, and kaldi::SplitStringToIntegers().

190  : opts_(opts), tmodel_(tmodel), log_priors_(log_priors),
191  supervision_(supervision), nnet_output_(nnet_output),
192  stats_(stats),
193  nnet_output_deriv_(nnet_output_deriv),
194  xent_output_deriv_(xent_output_deriv) {
195 
196  den_lat_ = supervision.den_lat;
197  TopSort(&den_lat_);
198 
200  &silence_phones_)) {
201  KALDI_ERR << "Bad value for --silence-phones option: "
203  }
204 }
bool SplitStringToIntegers(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< I > *out)
Split a string (e.g. "1:2:3") into a vector of integers, using the characters in delim as separators.
Definition: text-utils.h:68
#define KALDI_ERR
Definition: kaldi-error.h:127

Member Function Documentation

void Compute ( )

Definition at line 353 of file discriminative-training.cc.

References DiscriminativeObjectiveInfo::AccumulateGradients(), DiscriminativeObjectiveInfo::AccumulateOutput(), DiscriminativeObjectiveInfo::Add(), CuVectorBase< Real >::AddDiagMat2(), CuMatrixBase< Real >::AddMat(), CuMatrixBase< Real >::ApplyExp(), DiscriminativeOptions::boost, DiscriminativeComputation::ComputeObjfAndDeriv(), DiscriminativeComputation::ConvertAnswersToLogLike(), DiscriminativeOptions::criterion, DiscriminativeComputation::den_lat_, CuVectorBase< Real >::Dim(), DiscriminativeSupervision::frames_per_sequence, kaldi::GetVerboseLevel(), DiscriminativeObjectiveInfo::gradients, rnnlm::i, KALDI_ASSERT, KALDI_LOG, KALDI_WARN, kaldi::kNoTrans, kaldi::kTrans, DiscriminativeOptions::l2_regularize, DiscriminativeComputation::LatticeAcousticRescore(), kaldi::LatticeBoost(), DiscriminativeComputation::log_priors_, DiscriminativeComputation::LookupNnetOutput(), DiscriminativeComputation::nnet_output_, DiscriminativeComputation::nnet_output_deriv_, DiscriminativeSupervision::num_ali, DiscriminativeSupervision::num_sequences, CuMatrixBase< Real >::NumCols(), CuMatrixBase< Real >::NumRows(), DiscriminativeComputation::opts_, DiscriminativeObjectiveInfo::output, DiscriminativeObjectiveInfo::Print(), DiscriminativeObjectiveInfo::PrintAll(), DiscriminativeComputation::ProcessPosteriors(), DiscriminativeObjectiveInfo::Reset(), CuMatrixBase< Real >::SetZero(), DiscriminativeComputation::silence_phones_, DiscriminativeComputation::stats_, DiscriminativeComputation::supervision_, DiscriminativeComputation::tmodel_, DiscriminativeObjectiveInfo::tot_den_count, DiscriminativeObjectiveInfo::tot_l2_term, DiscriminativeObjectiveInfo::tot_num_count, DiscriminativeObjectiveInfo::tot_num_objf, DiscriminativeObjectiveInfo::tot_objf, DiscriminativeObjectiveInfo::tot_t, DiscriminativeObjectiveInfo::tot_t_weighted, DiscriminativeObjectiveInfo::TotalObjf(), kaldi::TraceMatMat(), DiscriminativeSupervision::weight, and DiscriminativeComputation::xent_output_deriv_.

Referenced by kaldi::discriminative::ComputeDiscriminativeObjfAndDeriv().

353  {
354  if (opts_.criterion == "mmi" && opts_.boost != 0.0) {
355  BaseFloat max_silence_error = 0.0;
357  opts_.boost, max_silence_error, &den_lat_);
358  }
359 
361 
362  int32 num_pdfs = nnet_output_.NumCols();
363  KALDI_ASSERT(log_priors_.Dim() == 0 || num_pdfs == log_priors_.Dim());
364 
365  // We need to look up the nnet output for some pdf-ids.
366  // Rather than looking them all up using operator (), which is
367  // very slow because each lookup involves a separate CUDA call with
368  // communication over PciExpress, we look them up all at once using
369  // CuMatrix::Lookup().
370  std::vector<BaseFloat> answers;
371  std::vector<Int32Pair> requested_indexes;
372 
373  LookupNnetOutput(&requested_indexes, &answers);
374 
375  ConvertAnswersToLogLike(requested_indexes, &answers);
376 
377  size_t index = 0;
378 
379  // Now put the negative (scaled) acoustic log-likelihoods in the lattice.
380  index = LatticeAcousticRescore(answers, index, &den_lat_);
381  // index is now the number of indexes of log-likes used to rescore lattice.
382  // This is required to further lookup answers for computing "mmi"
383  // numerator score.
384 
385  // Get statistics for this minibatch
386  DiscriminativeObjectiveInfo this_stats;
387  if (stats_) {
388  this_stats = *stats_;
389  this_stats.Reset();
390  }
391 
392  // Look up numerator probabilities corresponding to alignment
393  if (opts_.criterion == "mmi") {
394  double tot_num_like = 0.0;
395  KALDI_ASSERT(index + supervision_.num_ali.size() == answers.size());
396  for (size_t this_index = 0; this_index < supervision_.num_ali.size(); this_index++) {
397  tot_num_like += answers[index + this_index];
398  }
399  this_stats.tot_num_objf += supervision_.weight * tot_num_like;
400  index += supervision_.num_ali.size();
401  }
402 
403  KALDI_ASSERT(index == answers.size());
404 
405  if (nnet_output_deriv_) {
409  }
410 
411  if (xent_output_deriv_) {
415  }
416 
417  Posterior post;
418  Posterior xent_post;
419  double objf = ComputeObjfAndDeriv(&post,
420  (xent_output_deriv_ ? &xent_post : NULL));
421 
422  this_stats.tot_objf += supervision_.weight * objf;
423 
424  KALDI_ASSERT(nnet_output_.NumRows() == post.size());
425 
426  CuMatrix<BaseFloat> output_deriv;
427 
428  CuMatrixBase<BaseFloat> *output_deriv_temp;
429 
430  if (nnet_output_deriv_)
431  output_deriv_temp = nnet_output_deriv_;
432  else {
433  // This is for accumulating the statistics
434  output_deriv.Resize(nnet_output_.NumRows(), nnet_output_.NumCols());
435  output_deriv_temp = &output_deriv;
436  }
437 
438  double tot_num_post = 0.0, tot_den_post = 0.0;
439  {
440  ProcessPosteriors(post, output_deriv_temp,
441  &tot_num_post, &tot_den_post);
442  }
443 
444  if (xent_output_deriv_) {
445  ProcessPosteriors(xent_post, xent_output_deriv_, NULL, NULL);
446  }
447 
448  this_stats.tot_den_count += tot_den_post;
449  this_stats.tot_num_count += tot_num_post;
450 
451  if (this_stats.AccumulateGradients())
452  (this_stats.gradients).AddRowSumMat(1.0, CuMatrix<double>(*output_deriv_temp));
453 
454  if (this_stats.AccumulateOutput()) {
455  CuMatrix<double> temp(nnet_output_);
456  temp.ApplyExp();
457  (this_stats.output).AddRowSumMat(1.0, temp);
458  }
459 
460  this_stats.tot_t = num_frames;
461  this_stats.tot_t_weighted = num_frames * supervision_.weight;
462 
463  if (!(this_stats.TotalObjf(opts_.criterion) ==
464  this_stats.TotalObjf(opts_.criterion))) {
465  // inf or NaN detected
466  if (nnet_output_deriv_)
468  BaseFloat default_objf = -10;
469  KALDI_WARN << "Objective function is "
470  << this_stats.TotalObjf(opts_.criterion)
471  << ", setting to " << default_objf << " per frame.";
472  this_stats.tot_objf = default_objf * this_stats.tot_t_weighted;
473  }
474 
475  if (GetVerboseLevel() >= 2) {
476  if (GetVerboseLevel() >= 3) {
477  this_stats.PrintAll(opts_.criterion);
478  } else
479  this_stats.Print(opts_.criterion);
480  }
481 
482  // This code helps us see how big the derivatives are, on average,
483  // for different frames of the sequences. As expected, they are
484  // smaller towards the edges of the sequences (due to the penalization
485  // of 'incorrect' pdf-ids.
486  if (nnet_output_deriv_ && GetVerboseLevel() >= 1) {
487  int32 tot_frames = nnet_output_deriv_->NumRows(),
488  frames_per_sequence = supervision_.frames_per_sequence,
489  num_sequences = supervision_.num_sequences;
490  CuVector<BaseFloat> row_products(tot_frames);
491  row_products.AddDiagMat2(1.0, *nnet_output_deriv_, kNoTrans, 0.0);
492  Vector<BaseFloat> row_products_cpu(row_products);
493  Vector<BaseFloat> row_products_per_frame(frames_per_sequence);
494  for (int32 i = 0; i < tot_frames; i++)
495  row_products_per_frame(i / num_sequences) += row_products_cpu(i);
496  KALDI_LOG << "Derivs per frame are " << row_products_per_frame;
497  }
498 
499  if (opts_.l2_regularize != 0.0) {
500  // compute the l2 penalty term and its derivative
502  this_stats.tot_l2_term += -0.5 * scale * TraceMatMat(nnet_output_, nnet_output_, kTrans);
503  if (nnet_output_deriv_)
504  nnet_output_deriv_->AddMat(-1.0 * scale, nnet_output_);
505  }
506 
507  if (stats_)
508  stats_->Add(this_stats);
509 
510 }
int32 GetVerboseLevel()
Definition: kaldi-error.h:69
MatrixIndexT NumCols() const
Definition: cu-matrix.h:196
double ComputeObjfAndDeriv(Posterior *post, Posterior *xent_post)
void Add(const DiscriminativeObjectiveInfo &other)
float BaseFloat
Definition: kaldi-types.h:29
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an utterance.
Definition: posterior.h:43
MatrixIndexT Dim() const
Dimensions.
Definition: cu-vector.h:67
MatrixIndexT NumRows() const
Dimensions.
Definition: cu-matrix.h:195
static size_t LatticeAcousticRescore(const std::vector< BaseFloat > &answers, size_t index, Lattice *lat)
void ConvertAnswersToLogLike(const std::vector< Int32Pair > &requested_indexes, std::vector< BaseFloat > *answers) const
void LookupNnetOutput(std::vector< Int32Pair > *requested_indexes, std::vector< BaseFloat > *answers) const
void SetZero()
Math operations, some calling kernels.
Definition: cu-matrix.cc:474
#define KALDI_WARN
Definition: kaldi-error.h:130
Real TraceMatMat(const MatrixBase< Real > &A, const MatrixBase< Real > &B, MatrixTransposeType trans)
Returns the trace of the product of A and B (with B transposed when trans == kTrans). It is declared here because it is a friend function.
void AddMat(Real alpha, const CuMatrixBase< Real > &A, MatrixTransposeType trans=kNoTrans)
*this += alpha * A
Definition: cu-matrix.cc:939
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:169
void ProcessPosteriors(const Posterior &post, CuMatrixBase< BaseFloat > *output_deriv_temp, double *tot_num_post=NULL, double *tot_den_post=NULL) const
bool LatticeBoost(const TransitionModel &trans, const std::vector< int32 > &alignment, const std::vector< int32 > &silence_phones, BaseFloat b, BaseFloat max_silence_error, Lattice *lat)
Boosts LM probabilities by b * [number of frame errors]; equivalently, adds -b * [number of frame errors] to the graph (LM) component of the cost.
#define KALDI_LOG
Definition: kaldi-error.h:133
double ComputeObjfAndDeriv ( Posterior *  post,
Posterior *  xent_post 
)
private

Definition at line 512 of file discriminative-training.cc.

References kaldi::AlignmentToPosterior(), kaldi::ConvertPosteriorToPdfs(), DiscriminativeOptions::criterion, DiscriminativeComputation::den_lat_, DiscriminativeOptions::drop_frames, KALDI_ERR, kaldi::LatticeForwardBackwardMmi(), kaldi::LatticeForwardBackwardMpeVariants(), DiscriminativeSupervision::num_ali, DiscriminativeOptions::one_silence_class, DiscriminativeComputation::opts_, DiscriminativeComputation::silence_phones_, DiscriminativeComputation::supervision_, and DiscriminativeComputation::tmodel_.

Referenced by DiscriminativeComputation::Compute().

513  {
514 
515  if (xent_post) {
516  Posterior tid_post;
517  // Compute posterior from the numerator alignment
519  ConvertPosteriorToPdfs(tmodel_, tid_post, xent_post);
520  }
521 
522  if (opts_.criterion == "mpfe" || opts_.criterion == "smbr") {
523  Posterior tid_post;
525  den_lat_,
528  &tid_post);
529  ConvertPosteriorToPdfs(tmodel_, tid_post, post);
530  return ans;
531  } else if (opts_.criterion == "mmi") {
532  bool convert_to_pdfs = true, cancel = true;
533  // we'll return the denominator-lattice forward backward likelihood,
534  // which is one term in the objective function.
536  opts_.drop_frames, convert_to_pdfs,
537  cancel, post));
538  } else {
539  KALDI_ERR << "Unknown criterion " << opts_.criterion;
540  }
541 
542  return 0;
543 }
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an utterance.
Definition: posterior.h:43
BaseFloat LatticeForwardBackwardMmi(const TransitionModel &tmodel, const Lattice &lat, const std::vector< int32 > &num_ali, bool drop_frames, bool convert_to_pdf_ids, bool cancel, Posterior *post)
This function can be used to compute posteriors for MMI, with a positive contribution for the numerator (alignment) and a negative contribution for the denominator (lattice).
void AlignmentToPosterior(const std::vector< int32 > &ali, Posterior *post)
Convert an alignment to a posterior (with a scale of 1.0 on each entry).
Definition: posterior.cc:290
#define KALDI_ERR
Definition: kaldi-error.h:127
BaseFloat LatticeForwardBackwardMpeVariants(const TransitionModel &trans, const std::vector< int32 > &silence_phones, const Lattice &lat, const std::vector< int32 > &num_ali, std::string criterion, bool one_silence_class, Posterior *post)
This function implements either the MPFE (minimum phone frame error) or SMBR (state-level minimum Bayes risk) forward-backward computation, depending on the criterion argument.
void ConvertPosteriorToPdfs(const TransitionModel &tmodel, const Posterior &post_in, Posterior *post_out)
Converts a posterior over transition-ids to be a posterior over pdf-ids.
Definition: posterior.cc:322
void ConvertAnswersToLogLike ( const std::vector< Int32Pair > &  requested_indexes,
std::vector< BaseFloat > *  answers 
) const
private

Definition at line 263 of file discriminative-training.cc.

References DiscriminativeOptions::acoustic_scale, CuVectorBase< Real >::Dim(), KALDI_ASSERT, KALDI_ISINF, KALDI_ISNAN, KALDI_WARN, kaldi::Log(), DiscriminativeComputation::log_priors_, and DiscriminativeComputation::opts_.

Referenced by DiscriminativeComputation::Compute().

265  {
266  int32 num_floored = 0;
267 
268  BaseFloat floor_val = -20 * kaldi::Log(10.0); // floor for posteriors.
269  size_t index;
270 
271  Vector<BaseFloat> log_priors(log_priors_);
272 
273  // Replace "answers" with the vector of scaled log-probs. If this step takes
274  // too much time, we can look at other ways to do it, using the CUDA card.
275  for (index = 0; index < answers->size(); index++) {
276  BaseFloat log_post = (*answers)[index];
277  if (log_post < floor_val) {
278  // TODO: this might not be required for 'chain' models
279  log_post = floor_val;
280  num_floored++;
281  }
282 
283  if (log_priors_.Dim() > 0) {
284  int32 pdf_id = requested_indexes[index].second;
285  KALDI_ASSERT(log_post <= 0 && log_priors(pdf_id) <= 0);
286  BaseFloat pseudo_loglike = (log_post - log_priors(pdf_id))
288  KALDI_ASSERT(!KALDI_ISINF(pseudo_loglike) && !KALDI_ISNAN(pseudo_loglike));
289  (*answers)[index] = pseudo_loglike;
290  } else {
291  (*answers)[index] = log_post * opts_.acoustic_scale;
292  }
293  }
294 
295  if (num_floored > 0) {
296  KALDI_WARN << "Floored " << num_floored << " probabilities from nnet.";
297  }
298 }
#define KALDI_ISINF
Definition: kaldi-math.h:73
float BaseFloat
Definition: kaldi-types.h:29
double Log(double x)
Definition: kaldi-math.h:100
MatrixIndexT Dim() const
Dimensions.
Definition: cu-vector.h:67
#define KALDI_WARN
Definition: kaldi-error.h:130
#define KALDI_ISNAN
Definition: kaldi-math.h:72
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:169
size_t LatticeAcousticRescore ( const std::vector< BaseFloat > &  answers,
size_t  index,
Lattice *  lat 
)
static private

Definition at line 300 of file discriminative-training.cc.

References LatticeWeightTpl< FloatType >::SetValue2(), and LatticeWeightTpl< BaseFloat >::Zero().

Referenced by DiscriminativeComputation::Compute().

302  {
303  int32 num_states = lat->NumStates();
304 
305  for (StateId s = 0; s < num_states; s++) {
306  for (fst::MutableArcIterator<Lattice> aiter(lat, s);
307  !aiter.Done(); aiter.Next()) {
308  Arc arc = aiter.Value();
309  if (arc.ilabel != 0) { // input-side has transition-ids, output-side empty
310  arc.weight.SetValue2(-answers[index]);
311  index++;
312  aiter.SetValue(arc);
313  }
314  }
315  LatticeWeight final = lat->Final(s);
316  if (final != LatticeWeight::Zero()) {
317  final.SetValue2(0.0); // make sure no acoustic term in final-prob.
318  lat->SetFinal(s, final);
319  }
320  }
321 
322  // Number of indexes of log-likes used to rescore lattice
323  return index;
324 }
static const LatticeWeightTpl Zero()
fst::LatticeWeightTpl< BaseFloat > LatticeWeight
Definition: kaldi-lattice.h:32
void LookupNnetOutput ( std::vector< Int32Pair > *  requested_indexes,
std::vector< BaseFloat > *  answers 
) const
private

Definition at line 206 of file discriminative-training.cc.

References DiscriminativeOptions::criterion, DiscriminativeComputation::den_lat_, DiscriminativeSupervision::frames_per_sequence, KALDI_ASSERT, kaldi::LatticeStateTimes(), CuMatrixBase< Real >::Lookup(), DiscriminativeComputation::MakePair(), DiscriminativeComputation::nnet_output_, DiscriminativeSupervision::num_ali, DiscriminativeSupervision::num_sequences, TransitionModel::NumPdfs(), DiscriminativeComputation::opts_, DiscriminativeComputation::supervision_, DiscriminativeComputation::tmodel_, and TransitionModel::TransitionIdToPdf().

Referenced by DiscriminativeComputation::Compute().

208  {
209  BaseFloat wiggle_room = 1.3; // value not critical.. it's just 'reserve'
210 
212  int32 num_pdfs = tmodel_.NumPdfs();
213 
214  int32 num_reserve = wiggle_room * den_lat_.NumStates();
215 
216  if (opts_.criterion == "mmi") {
217  // For looking up the posteriors corresponding to the pdfs in the alignment
218  num_reserve += num_frames;
219  }
220 
221  requested_indexes->reserve(num_reserve);
222 
223  // Denominator probabilities to look up from denominator lattice
224  std::vector<int32> state_times;
225  int32 T = LatticeStateTimes(den_lat_, &state_times);
226  KALDI_ASSERT(T == num_frames);
227 
228  StateId num_states = den_lat_.NumStates();
229  for (StateId s = 0; s < num_states; s++) {
230  int32 t = state_times[s];
231  int32 seq = t / supervision_.frames_per_sequence,
233 
234  for (fst::ArcIterator<Lattice> aiter(den_lat_, s); !aiter.Done(); aiter.Next()) {
235  const Arc &arc = aiter.Value();
236  if (arc.ilabel != 0) { // input-side has transition-ids, output-side empty
237  int32 tid = arc.ilabel, pdf_id = tmodel_.TransitionIdToPdf(tid);
238  // The ordering of the indexes is similar to that in chain models
239  requested_indexes->push_back(MakePair(idx * supervision_.num_sequences + seq, pdf_id));
240  }
241  }
242  }
243 
244  if (opts_.criterion == "mmi") {
245  // Numerator probabilities to look up from alignment
246  for (int32 t = 0; t < num_frames; t++) {
247  int32 seq = t / supervision_.frames_per_sequence,
249  int32 tid = supervision_.num_ali[t],
250  pdf_id = tmodel_.TransitionIdToPdf(tid);
251  KALDI_ASSERT(pdf_id >= 0 && pdf_id < num_pdfs);
252  requested_indexes->push_back(MakePair(idx * supervision_.num_sequences + seq, pdf_id));
253  }
254  }
255 
256  CuArray<Int32Pair> cu_requested_indexes(*requested_indexes);
257  answers->resize(requested_indexes->size());
258  nnet_output_.Lookup(cu_requested_indexes, &((*answers)[0]));
259  // requested_indexes now contain (t, j) pair and answers contains the
260  // neural network output, which is log p(j|x(t)) for CE models
261 }
int32 LatticeStateTimes(const Lattice &lat, vector< int32 > *times)
This function iterates over the states of a topologically sorted lattice and counts the time instance corresponding to each state, storing it in times. It returns the total number of frames.
int32 TransitionIdToPdf(int32 trans_id) const
static Int32Pair MakePair(int32 first, int32 second)
float BaseFloat
Definition: kaldi-types.h:29
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:169
void Lookup(const std::vector< Int32Pair > &indexes, Real *output) const
Definition: cu-matrix.cc:3064
static Int32Pair MakePair ( int32  first,
int32  second 
)
inline static private

Definition at line 173 of file discriminative-training.cc.

References Int32Pair::first, and Int32Pair::second.

Referenced by DiscriminativeComputation::LookupNnetOutput(), and DiscriminativeComputation::ProcessPosteriors().

173  {
174  Int32Pair ans;
175  ans.first = first;
176  ans.second = second;
177  return ans;
178  }
int32_cuda second
Definition: cu-matrixdim.h:86
int32_cuda first
Definition: cu-matrixdim.h:85
void ProcessPosteriors ( const Posterior &  post,
CuMatrixBase< BaseFloat > *  output_deriv_temp,
double *  tot_num_post = NULL,
double *  tot_den_post = NULL 
) const
private

Definition at line 326 of file discriminative-training.cc.

References CuMatrixBase< Real >::AddElements(), DiscriminativeSupervision::frames_per_sequence, rnnlm::j, DiscriminativeComputation::MakePair(), DiscriminativeSupervision::num_sequences, DiscriminativeComputation::supervision_, and DiscriminativeSupervision::weight.

Referenced by DiscriminativeComputation::Compute().

330  {
331  std::vector<Int32Pair> deriv_indexes;
332  std::vector<BaseFloat> deriv_data;
333  for (size_t t = 0; t < post.size(); t++) {
334  for (size_t j = 0; j < post[t].size(); j++) {
335  int32 seq = t / supervision_.frames_per_sequence,
337  int32 pdf_id = post[t][j].first;
338 
339  // Same ordering as for 'chain' models
340  deriv_indexes.push_back(MakePair(idx * supervision_.num_sequences + seq, pdf_id));
341 
342  BaseFloat weight = post[t][j].second;
343  if (tot_num_post && weight > 0.0) *tot_num_post += weight;
344  if (tot_den_post && weight < 0.0) *tot_den_post -= weight;
345  deriv_data.push_back(weight);
346  }
347  }
348  CuArray<Int32Pair> cu_deriv_indexes(deriv_indexes);
349  output_deriv_temp->AddElements(supervision_.weight, cu_deriv_indexes,
350  deriv_data.data());
351 }
static Int32Pair MakePair(int32 first, int32 second)
void AddElements(Real alpha, const std::vector< MatrixElement< Real > > &input)
Definition: cu-matrix.cc:2998
float BaseFloat
Definition: kaldi-types.h:29

Member Data Documentation

CuMatrixBase<BaseFloat>* nnet_output_deriv_
private

Definition at line 129 of file discriminative-training.cc.

Referenced by DiscriminativeComputation::Compute().

DiscriminativeObjectiveInfo* stats_
private

Definition at line 126 of file discriminative-training.cc.

Referenced by DiscriminativeComputation::Compute().

CuMatrixBase<BaseFloat>* xent_output_deriv_
private

Definition at line 135 of file discriminative-training.cc.

Referenced by DiscriminativeComputation::Compute().


The documentation for this class was generated from the following file: