NnetDiscriminativeUpdater Class Reference

Public Member Functions

 NnetDiscriminativeUpdater (const AmNnet &am_nnet, const TransitionModel &tmodel, const NnetDiscriminativeUpdateOptions &opts, const DiscriminativeNnetExample &eg, Nnet *nnet_to_update, NnetDiscriminativeStats *stats)
 
void Update ()
 
void Propagate ()
 The forward-through-the-layers part of the computation.
 
void LatticeComputations ()
 Performs the steps between Propagate() and Backprop() that involve forward-backward over the lattice.
 
void Backprop ()
 
double GetDiscriminativePosteriors (Posterior *post)
 Assuming the lattice already has the correct scores in it, this function does the MPE or MMI forward-backward and puts the resulting discriminative posteriors (which may have positive or negative weight) into "post".
 
SubMatrix< BaseFloat > GetInputFeatures () const
 
CuMatrixBase< BaseFloat > & GetOutput ()
 

Static Public Member Functions

static Int32Pair MakePair (int32 first, int32 second)
 

Private Types

typedef LatticeArc Arc
 
typedef Arc::StateId StateId
 

Private Attributes

const AmNnet & am_nnet_
 
const TransitionModel & tmodel_
 
const NnetDiscriminativeUpdateOptions & opts_
 
const DiscriminativeNnetExample & eg_
 
Nnet * nnet_to_update_
 
NnetDiscriminativeStats * stats_
 
std::vector< ChunkInfo > chunk_info_out_
 
std::vector< CuMatrix< BaseFloat > > forward_data_
 
Lattice lat_
 
CuMatrix< BaseFloat > backward_data_
 
std::vector< int32 > silence_phones_
 

Detailed Description

Definition at line 32 of file nnet-compute-discriminative.cc.

Member Typedef Documentation

◆ Arc

typedef LatticeArc Arc
private

Definition at line 78 of file nnet-compute-discriminative.cc.

◆ StateId

typedef Arc::StateId StateId
private

Definition at line 79 of file nnet-compute-discriminative.cc.

Constructor & Destructor Documentation

◆ NnetDiscriminativeUpdater()

NnetDiscriminativeUpdater ( const AmNnet & am_nnet,
const TransitionModel & tmodel,
const NnetDiscriminativeUpdateOptions & opts,
const DiscriminativeNnetExample & eg,
Nnet * nnet_to_update,
NnetDiscriminativeStats * stats 
)

Definition at line 101 of file nnet-compute-discriminative.cc.

References NnetDiscriminativeUpdater::am_nnet_, NnetDiscriminativeUpdater::chunk_info_out_, Nnet::ComputeChunkInfo(), NnetDiscriminativeUpdater::eg_, AmNnet::GetNnet(), DiscriminativeNnetExample::input_frames, KALDI_ERR, MatrixBase< Real >::NumRows(), NnetDiscriminativeUpdater::opts_, NnetDiscriminativeUpdater::silence_phones_, NnetDiscriminativeUpdateOptions::silence_phones_str, and kaldi::SplitStringToIntegers().

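NnetDiscriminativeUpdater::NnetDiscriminativeUpdater(
    const AmNnet &am_nnet,
    const TransitionModel &tmodel,
    const NnetDiscriminativeUpdateOptions &opts,
    const DiscriminativeNnetExample &eg,
    Nnet *nnet_to_update,
    NnetDiscriminativeStats *stats):
    am_nnet_(am_nnet), tmodel_(tmodel), opts_(opts), eg_(eg),
    nnet_to_update_(nnet_to_update), stats_(stats) {
  if (!SplitStringToIntegers(opts_.silence_phones_str, ":", false,
                             &silence_phones_)) {
    KALDI_ERR << "Bad value for --silence-phones option: "
              << opts_.silence_phones_str;
  }
  const Nnet &nnet = am_nnet_.GetNnet();
  nnet.ComputeChunkInfo(eg_.input_frames.NumRows(), 1, &chunk_info_out_);
}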
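The constructor converts opts_.silence_phones_str into the integer list silence_phones_ and reports an error through KALDI_ERR when the string cannot be parsed. The sketch below isolates that validation step in standard C++. It assumes a colon-separated list such as 1:2:3 (the usual form of Kaldi's --silence-phones value); ParseIntList is a hypothetical helper, not Kaldi's SplitStringToIntegers(), and its handling of edge cases is this sketch's own choice.

```cpp
#include <cassert>
#include <cerrno>
#include <climits>
#include <cstdlib>
#include <string>
#include <vector>

// Hypothetical helper (not Kaldi's SplitStringToIntegers): splits a list such
// as "1:2:3" on `delim` and converts each field to an int. Returns false and
// clears *out if any field is empty, non-numeric or outside the int range.
// An empty input string yields an empty list (an assumption of this sketch).
bool ParseIntList(const std::string &s, char delim, std::vector<int> *out) {
  out->clear();
  if (s.empty()) return true;
  std::string::size_type start = 0;
  while (true) {
    std::string::size_type end = s.find(delim, start);
    std::string field = (end == std::string::npos)
                            ? s.substr(start)
                            : s.substr(start, end - start);
    if (field.empty()) { out->clear(); return false; }
    char *rest = nullptr;
    errno = 0;
    long value = std::strtol(field.c_str(), &rest, 10);
    if (*rest != '\0' || errno == ERANGE || value < INT_MIN || value > INT_MAX) {
      out->clear();
      return false;
    }
    out->push_back(static_cast<int>(value));
    if (end == std::string::npos) return true;  // last field consumed
    start = end + 1;
  }
}
```

In the updater, a false return corresponds to the KALDI_ERR branch.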

Member Function Documentation

◆ Backprop()

void Backprop ( )

Definition at line 349 of file nnet-compute-discriminative.cc.

References NnetDiscriminativeUpdater::am_nnet_, Component::Backprop(), NnetDiscriminativeUpdater::backward_data_, NnetDiscriminativeUpdater::chunk_info_out_, NnetDiscriminativeUpdater::forward_data_, Nnet::GetComponent(), AmNnet::GetNnet(), NnetDiscriminativeUpdater::nnet_to_update_, and Nnet::NumComponents().

Referenced by NnetDiscriminativeUpdater::Update().

void NnetDiscriminativeUpdater::Backprop() {
  const Nnet &nnet = am_nnet_.GetNnet();
  for (int32 c = nnet.NumComponents() - 1; c >= 0; c--) {
    const Component &component = nnet.GetComponent(c);
    Component *component_to_update = &(nnet_to_update_->GetComponent(c));
    const CuMatrix<BaseFloat> &input = forward_data_[c],
        &output = forward_data_[c+1],
        &output_deriv = backward_data_;
    CuMatrix<BaseFloat> input_deriv;
    component.Backprop(chunk_info_out_[c], chunk_info_out_[c+1], input, output, output_deriv,
                       component_to_update, &input_deriv);
    backward_data_.Swap(&input_deriv); // backward_data_ = input_deriv.
  }
}
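Backprop() visits the components from last to first. For each one it reads the component's stored input and output from forward_data_, turns the derivative with respect to the output (backward_data_) into a derivative with respect to the input, passes component_to_update along so the corresponding component of nnet_to_update_ can be updated, and then swaps the new derivative into backward_data_ so the next (earlier) component receives it. The toy below reproduces that control flow with a hypothetical elementwise-scaling component; it is not Kaldi's Component interface, and it only accumulates a gradient instead of applying an update.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

typedef std::vector<double> Vec;

// Toy component computing y = w * x elementwise (a stand-in, not Kaldi's
// Component). grad_w accumulates dL/dw, playing the role of the statistics
// that component_to_update would receive.
struct ScaleComponent {
  double w;
  double grad_w;
};

// Forward pass: fwd[c] is the input of component c and fwd[c+1] its output,
// the same layout as forward_data_.
std::vector<Vec> Forward(const std::vector<ScaleComponent> &comps, const Vec &x) {
  std::vector<Vec> fwd(comps.size() + 1);
  fwd[0] = x;
  for (std::size_t c = 0; c < comps.size(); ++c) {
    fwd[c + 1] = fwd[c];
    for (double &v : fwd[c + 1]) v *= comps[c].w;
  }
  return fwd;
}

// Backward pass in reverse component order. On entry to step c, `deriv`
// (the analogue of backward_data_) holds dL/d(output of c); after the swap
// it holds dL/d(input of c). Returns dL/d(network input).
Vec Backward(std::vector<ScaleComponent> *comps, const std::vector<Vec> &fwd,
             Vec deriv) {
  for (int c = static_cast<int>(comps->size()) - 1; c >= 0; --c) {
    ScaleComponent &comp = (*comps)[c];
    const Vec &input = fwd[c];  // the stored input is needed to form dL/dw
    Vec input_deriv(deriv.size());
    for (std::size_t i = 0; i < deriv.size(); ++i) {
      comp.grad_w += input[i] * deriv[i];
      input_deriv[i] = comp.w * deriv[i];
    }
    deriv.swap(input_deriv);  // like backward_data_.Swap(&input_deriv)
  }
  return deriv;
}
```

For y = w0 * w1 * x with w0 = 2, w1 = 3, x = 1 and L = y, the sketch gives dL/dx = 6, dL/dw0 = 3 and dL/dw1 = 2, which matches differentiating the product by hand.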

◆ GetDiscriminativePosteriors()

double GetDiscriminativePosteriors ( Posterior post)

Assuming the lattice already has the correct scores in it, this function does the MPE or MMI forward-backward and puts the resulting discriminative posteriors (which may have positive or negative weight) into "post".

It returns, for MPFE/SMBR, the objective function, or for MMI, the negative of the denominator-lattice log-likelihood.

Definition at line 326 of file nnet-compute-discriminative.cc.

References kaldi::ConvertPosteriorToPdfs(), NnetDiscriminativeUpdateOptions::criterion, NnetDiscriminativeUpdateOptions::drop_frames, NnetDiscriminativeUpdater::eg_, KALDI_ASSERT, NnetDiscriminativeUpdater::lat_, kaldi::LatticeForwardBackwardMmi(), kaldi::LatticeForwardBackwardMpeVariants(), DiscriminativeNnetExample::num_ali, NnetDiscriminativeUpdateOptions::one_silence_class, NnetDiscriminativeUpdater::opts_, NnetDiscriminativeUpdater::silence_phones_, and NnetDiscriminativeUpdater::tmodel_.

Referenced by NnetDiscriminativeUpdater::LatticeComputations(), and NnetDiscriminativeUpdater::Update().

double NnetDiscriminativeUpdater::GetDiscriminativePosteriors(Posterior *post) {
  if (opts_.criterion == "mpfe" || opts_.criterion == "smbr") {
    Posterior tid_post;
    double ans;
    ans = LatticeForwardBackwardMpeVariants(tmodel_, silence_phones_, lat_,
                                            eg_.num_ali, opts_.criterion,
                                            opts_.one_silence_class,
                                            &tid_post);
    ConvertPosteriorToPdfs(tmodel_, tid_post, post);
    return ans; // returns the objective function.
  } else {
    KALDI_ASSERT(opts_.criterion == "mmi");
    bool convert_to_pdfs = true, cancel = true;
    // we'll return the denominator-lattice forward backward likelihood,
    // which is one term in the objective function.
    return LatticeForwardBackwardMmi(tmodel_, lat_, eg_.num_ali,
                                     opts_.drop_frames, convert_to_pdfs,
                                     cancel, post);
  }
}
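For MMI, the posteriors written to post combine a positive numerator part with a negative denominator-lattice part, which is why they may carry either sign. The listing sets cancel = true; as we understand Kaldi's documentation of LatticeForwardBackwardMmi(), this makes positive and negative entries for the same id cancel against each other. The sketch below illustrates that combination at the pdf level, assuming the denominator occupancies have already been computed; it does not run the lattice forward-backward, and MmiPosterior is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <utility>
#include <vector>

// Same shape as Kaldi's Posterior typedef: per frame, (id, weight) pairs.
typedef std::vector<std::vector<std::pair<int, float> > > Posterior;

// Hypothetical helper: the numerator alignment gives one pdf-id per frame
// with weight +1; denominator occupancies are subtracted. With cancel ==
// true, all entries for the same pdf-id on a frame are summed into one.
Posterior MmiPosterior(const std::vector<int> &num_pdfs,
                       const Posterior &den_occupancy, bool cancel) {
  assert(num_pdfs.size() == den_occupancy.size());
  Posterior post(num_pdfs.size());
  for (std::size_t t = 0; t < num_pdfs.size(); ++t) {
    if (cancel) {
      std::map<int, float> net;  // pdf-id -> summed weight, sorted by pdf-id
      net[num_pdfs[t]] += 1.0f;
      for (const auto &p : den_occupancy[t]) net[p.first] -= p.second;
      for (const auto &kv : net)
        post[t].push_back(std::make_pair(kv.first, kv.second));
    } else {
      post[t].push_back(std::make_pair(num_pdfs[t], 1.0f));
      for (const auto &p : den_occupancy[t])
        post[t].push_back(std::make_pair(p.first, -p.second));
    }
  }
  return post;
}
```

When the denominator lattice puts all its mass on the numerator pdf (frame 1 in the test), the cancelled weight is zero and that frame contributes no gradient.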

◆ GetInputFeatures()

SubMatrix< BaseFloat > GetInputFeatures ( ) const

Definition at line 121 of file nnet-compute-discriminative.cc.

References NnetDiscriminativeUpdater::am_nnet_, NnetDiscriminativeUpdater::eg_, AmNnet::GetNnet(), DiscriminativeNnetExample::input_frames, KALDI_ASSERT, DiscriminativeNnetExample::left_context, Nnet::LeftContext(), DiscriminativeNnetExample::num_ali, MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), and Nnet::RightContext().

Referenced by NnetDiscriminativeUpdater::Propagate(), and NnetDiscriminativeUpdater::Update().

SubMatrix<BaseFloat> NnetDiscriminativeUpdater::GetInputFeatures() const {
  int32 num_frames_output = eg_.num_ali.size();
  int32 eg_left_context = eg_.left_context,
      eg_right_context = eg_.input_frames.NumRows() -
      num_frames_output - eg_left_context;
  KALDI_ASSERT(eg_right_context >= 0);
  const Nnet &nnet = am_nnet_.GetNnet();
  // Make sure the example has enough acoustic left and right
  // context... normally we'll use examples generated using the same model,
  // which will have the exact context, but we enable a mismatch in context as
  // long as it is more, not less.
  KALDI_ASSERT(eg_left_context >= nnet.LeftContext() &&
               eg_right_context >= nnet.RightContext());
  int32 offset = eg_left_context - nnet.LeftContext(),
      num_output_frames =
      num_frames_output + nnet.LeftContext() + nnet.RightContext();
  SubMatrix<BaseFloat> ans(eg_.input_frames, offset, num_output_frames,
                           0, eg_.input_frames.NumCols());
  return ans;
}
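GetInputFeatures() trims any surplus context from the example so that the network sees exactly nnet.LeftContext() + num_frames + nnet.RightContext() rows. The arithmetic is easy to check in isolation; the sketch below restates it with plain ints (RowRange and InputRowRange are hypothetical names).

```cpp
#include <cassert>

// Hypothetical result type: which rows of the example's input_frames are used.
struct RowRange {
  int offset;    // first row used
  int num_rows;  // number of rows used
};

// Mirrors the arithmetic in GetInputFeatures(). The example may carry more
// left/right context than the network needs (the surplus is skipped), but
// never less.
RowRange InputRowRange(int num_input_rows, int num_frames_output,
                       int eg_left_context, int nnet_left_context,
                       int nnet_right_context) {
  int eg_right_context = num_input_rows - num_frames_output - eg_left_context;
  assert(eg_right_context >= 0);
  assert(eg_left_context >= nnet_left_context &&
         eg_right_context >= nnet_right_context);
  RowRange range;
  range.offset = eg_left_context - nnet_left_context;
  range.num_rows = num_frames_output + nnet_left_context + nnet_right_context;
  return range;
}
```

For example, with 30 input rows, 20 output frames and 6 frames of left context, the example has 4 frames of right context. A network needing 4 frames on the left and 3 on the right then uses rows 2 through 28 (27 rows), skipping 2 surplus rows on the left and 1 on the right.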

◆ GetOutput()

CuMatrixBase<BaseFloat>& GetOutput ( )
inline

Definition at line 68 of file nnet-compute-discriminative.cc.

References NnetDiscriminativeUpdater::forward_data_.

CuMatrixBase<BaseFloat> &GetOutput() { return forward_data_.back(); }

◆ LatticeComputations()

void LatticeComputations ( )

Performs the steps between Propagate() and Backprop() that involve forward-backward over the lattice.

Definition at line 178 of file nnet-compute-discriminative.cc.

References NnetDiscriminativeUpdateOptions::acoustic_scale, NnetDiscriminativeUpdater::am_nnet_, NnetDiscriminativeUpdater::backward_data_, NnetDiscriminativeUpdateOptions::boost, fst::ConvertLattice(), NnetDiscriminativeUpdateOptions::criterion, DiscriminativeNnetExample::den_lat, VectorBase< Real >::Dim(), NnetDiscriminativeUpdater::eg_, NnetDiscriminativeUpdater::forward_data_, NnetDiscriminativeUpdater::GetDiscriminativePosteriors(), AmNnet::GetNnet(), KALDI_ASSERT, KALDI_ISINF, KALDI_ISNAN, KALDI_WARN, NnetDiscriminativeUpdater::lat_, kaldi::LatticeBoost(), kaldi::LatticeStateTimes(), kaldi::Log(), CuMatrixBase< Real >::Lookup(), NnetDiscriminativeUpdater::MakePair(), DiscriminativeNnetExample::num_ali, CuMatrixBase< Real >::NumCols(), Nnet::NumComponents(), CuMatrixBase< Real >::NumRows(), NnetDiscriminativeUpdater::opts_, AmNnet::Priors(), kaldi::ScalePosterior(), NnetDiscriminativeUpdater::silence_phones_, NnetDiscriminativeUpdater::stats_, NnetDiscriminativeUpdater::tmodel_, NnetDiscriminativeStats::tot_den_objf, NnetDiscriminativeStats::tot_num_count, NnetDiscriminativeStats::tot_num_objf, NnetDiscriminativeStats::tot_t, NnetDiscriminativeStats::tot_t_weighted, TransitionModel::TransitionIdToPdf(), DiscriminativeNnetExample::weight, and LatticeWeightTpl< BaseFloat >::Zero().

Referenced by NnetDiscriminativeUpdater::Update().

void NnetDiscriminativeUpdater::LatticeComputations() {
  ConvertLattice(eg_.den_lat, &lat_); // convert to Lattice.
  TopSort(&lat_); // Topologically sort (required by forward-backward algorithms)

  if (opts_.criterion == "mmi" && opts_.boost != 0.0) {
    BaseFloat max_silence_error = 0.0;
    LatticeBoost(tmodel_, eg_.num_ali, silence_phones_,
                 opts_.boost, max_silence_error, &lat_);
  }

  int32 num_frames = static_cast<int32>(eg_.num_ali.size());

  stats_->tot_t += num_frames;
  stats_->tot_t_weighted += num_frames * eg_.weight;

  const VectorBase<BaseFloat> &priors = am_nnet_.Priors();
  const CuMatrix<BaseFloat> &posteriors = forward_data_.back();

  KALDI_ASSERT(posteriors.NumRows() == num_frames);
  int32 num_pdfs = posteriors.NumCols();
  KALDI_ASSERT(num_pdfs == priors.Dim());

  // We need to look up the posteriors of some pdf-ids in the matrix
  // "posteriors". Rather than looking them all up using operator (), which is
  // very slow because each lookup involves a separate CUDA call with
  // communication over PCI Express, we look them up all at once using
  // CuMatrix::Lookup().
  // Note: regardless of the criterion, we evaluate the likelihoods in
  // the numerator alignment. Even though they may be irrelevant to
  // the optimization, they will affect the value of the objective function.

  std::vector<Int32Pair> requested_indexes;
  BaseFloat wiggle_room = 1.3; // value not critical.. it's just 'reserve'
  requested_indexes.reserve(num_frames + wiggle_room * lat_.NumStates());

  if (opts_.criterion == "mmi") { // need numerator probabilities...
    for (int32 t = 0; t < num_frames; t++) {
      int32 tid = eg_.num_ali[t], pdf_id = tmodel_.TransitionIdToPdf(tid);
      KALDI_ASSERT(pdf_id >= 0 && pdf_id < num_pdfs);
      requested_indexes.push_back(MakePair(t, pdf_id));
    }
  }

  std::vector<int32> state_times;
  int32 T = LatticeStateTimes(lat_, &state_times);
  KALDI_ASSERT(T == num_frames);

  StateId num_states = lat_.NumStates();
  for (StateId s = 0; s < num_states; s++) {
    StateId t = state_times[s];
    for (fst::ArcIterator<Lattice> aiter(lat_, s); !aiter.Done(); aiter.Next()) {
      const Arc &arc = aiter.Value();
      if (arc.ilabel != 0) { // input-side has transition-ids, output-side empty
        int32 tid = arc.ilabel, pdf_id = tmodel_.TransitionIdToPdf(tid);
        requested_indexes.push_back(MakePair(t, pdf_id));
      }
    }
  }

  std::vector<BaseFloat> answers;
  CuArray<Int32Pair> cu_requested_indexes(requested_indexes);
  answers.resize(requested_indexes.size());
  posteriors.Lookup(cu_requested_indexes, &(answers[0]));

  int32 num_floored = 0;

  BaseFloat floor_val = 1.0e-20; // floor for posteriors.
  size_t index;

  // Replace "answers" with the vector of scaled log-probs. If this step takes
  // too much time, we can look at other ways to do it, using the CUDA card.
  for (index = 0; index < answers.size(); index++) {
    BaseFloat post = answers[index];
    if (post < floor_val) {
      post = floor_val;
      num_floored++;
    }
    int32 pdf_id = requested_indexes[index].second;
    BaseFloat pseudo_loglike = Log(post / priors(pdf_id)) * opts_.acoustic_scale;
    KALDI_ASSERT(!KALDI_ISINF(pseudo_loglike) && !KALDI_ISNAN(pseudo_loglike));
    answers[index] = pseudo_loglike;
  }
  if (num_floored > 0) {
    KALDI_WARN << "Floored " << num_floored << " probabilities from nnet.";
  }

  index = 0;

  if (opts_.criterion == "mmi") {
    double tot_num_like = 0.0;
    for (; index < eg_.num_ali.size(); index++)
      tot_num_like += answers[index];
    stats_->tot_num_objf += eg_.weight * tot_num_like;
  }

  // Now put the (scaled) acoustic log-likelihoods in the lattice.
  for (StateId s = 0; s < num_states; s++) {
    for (fst::MutableArcIterator<Lattice> aiter(&lat_, s);
         !aiter.Done(); aiter.Next()) {
      Arc arc = aiter.Value();
      if (arc.ilabel != 0) { // input-side has transition-ids, output-side empty
        arc.weight.SetValue2(-answers[index]);
        index++;
        aiter.SetValue(arc);
      }
    }
    LatticeWeight final = lat_.Final(s);
    if (final != LatticeWeight::Zero()) {
      final.SetValue2(0.0); // make sure no acoustic term in final-prob.
      lat_.SetFinal(s, final);
    }
  }
  KALDI_ASSERT(index == answers.size());

  // Get the MPE or MMI posteriors.
  Posterior post;
  stats_->tot_den_objf += eg_.weight * GetDiscriminativePosteriors(&post);

  ScalePosterior(eg_.weight, &post);

  double tot_num_post = 0.0, tot_den_post = 0.0;
  std::vector<MatrixElement<BaseFloat> > sv_labels;
  sv_labels.reserve(answers.size());
  for (int32 t = 0; t < post.size(); t++) {
    for (int32 i = 0; i < post[t].size(); i++) {
      int32 pdf_id = post[t][i].first;
      BaseFloat weight = post[t][i].second;
      if (weight > 0.0) { tot_num_post += weight; }
      else { tot_den_post -= weight; }
      MatrixElement<BaseFloat> elem = {t, pdf_id, weight};
      sv_labels.push_back(elem);
    }
  }
  stats_->tot_num_count += tot_num_post;
  int32 num_components = am_nnet_.GetNnet().NumComponents();
  const CuMatrix<BaseFloat> &output(forward_data_[num_components]);
  backward_data_.Resize(output.NumRows(), output.NumCols()); // zeroes it.

  { // We don't actually need tot_objf and tot_weight; we have already
    // computed the objective function.
    BaseFloat tot_objf, tot_weight;
    backward_data_.CompObjfAndDeriv(sv_labels, output, &tot_objf, &tot_weight);
    // Now backward_data_ will contain the derivative at the output.
    // Our work here is done..
  }
}
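In LatticeComputations(), every requested (frame, pdf-id) posterior is floored at 1e-20, divided by that pdf's prior and converted to acoustic_scale * log(posterior / prior). Dividing by the prior turns the network's posterior p(s|x) into a scaled likelihood: by Bayes' rule p(x|s) = p(s|x) p(x) / p(s), and p(x) is the same for every pdf on a given frame, hence the name "pseudo log-likelihood". The real code gathers all values in one batched CuMatrix::Lookup() call to avoid a separate GPU transfer per element; the CPU sketch below folds the gather and the conversion into a single loop. The Matrix type and ScaledLogLikes are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <utility>
#include <vector>

// Hypothetical row-major matrix of network outputs: frames x pdf-ids.
struct Matrix {
  int rows;
  int cols;
  std::vector<float> data;
  float operator()(int r, int c) const { return data[r * cols + c]; }
};

// For each (frame, pdf-id) request: gather the posterior, floor it, divide
// by the pdf's prior and return acoustic_scale * log(posterior / prior).
// *num_floored counts how many posteriors fell below the floor.
std::vector<float> ScaledLogLikes(
    const Matrix &posteriors,
    const std::vector<std::pair<int, int> > &requests,
    const std::vector<float> &priors, float acoustic_scale,
    int *num_floored) {
  const float floor_val = 1.0e-20f;  // same floor value as the listing
  std::vector<float> answers;
  answers.reserve(requests.size());
  *num_floored = 0;
  for (const auto &req : requests) {
    int t = req.first, pdf_id = req.second;
    assert(t >= 0 && t < posteriors.rows);
    assert(pdf_id >= 0 && pdf_id < posteriors.cols);
    float post = posteriors(t, pdf_id);
    if (post < floor_val) {
      post = floor_val;
      (*num_floored)++;
    }
    float loglike = acoustic_scale * std::log(post / priors[pdf_id]);
    assert(std::isfinite(loglike));  // mirrors the KALDI_ISINF/KALDI_ISNAN check
    answers.push_back(loglike);
  }
  return answers;
}
```

The floor keeps a zero posterior from producing log(0); the floored value still yields a large negative log-likelihood, which effectively rules that arc out.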
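After GetDiscriminativePosteriors() and ScalePosterior(), LatticeComputations() turns each (pdf-id, weight) entry on frame t into a sparse element (t, pdf-id, weight) for CompObjfAndDeriv() and totals the positive and negative weights separately; only the positive total is added to stats_->tot_num_count. A standalone sketch of that loop, with hypothetical Element and FlattenPosterior names:

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

typedef std::vector<std::vector<std::pair<int, float> > > Posterior;

// Hypothetical (row, column, weight) triple, in the spirit of the
// MatrixElement<BaseFloat> entries collected in sv_labels.
struct Element {
  int row;
  int col;
  float weight;
};

// Flattens a per-frame posterior into sparse elements and totals the
// positive (numerator-side) and negative (denominator-side) weights. The
// denominator total is stored as a positive number, as in the listing.
std::vector<Element> FlattenPosterior(const Posterior &post,
                                      double *tot_num_post,
                                      double *tot_den_post) {
  std::vector<Element> labels;
  *tot_num_post = 0.0;
  *tot_den_post = 0.0;
  for (std::size_t t = 0; t < post.size(); ++t) {
    for (const auto &entry : post[t]) {
      if (entry.second > 0.0f)
        *tot_num_post += entry.second;
      else
        *tot_den_post -= entry.second;
      Element e = {static_cast<int>(t), entry.first, entry.second};
      labels.push_back(e);
    }
  }
  return labels;
}
```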

◆ MakePair()

static Int32Pair MakePair ( int32  first,
int32  second 
)
inline static

Definition at line 70 of file nnet-compute-discriminative.cc.

References Int32Pair::first, and Int32Pair::second.

Referenced by NnetDiscriminativeUpdater::LatticeComputations().

  static inline Int32Pair MakePair(int32 first, int32 second) {
    Int32Pair ans;
    ans.first = first;
    ans.second = second;
    return ans;
  }

◆ Propagate()

void Propagate ( )

The forward-through-the-layers part of the computation.

Definition at line 142 of file nnet-compute-discriminative.cc.

References NnetDiscriminativeUpdater::am_nnet_, Component::BackpropNeedsInput(), Component::BackpropNeedsOutput(), NnetDiscriminativeUpdater::chunk_info_out_, NnetDiscriminativeUpdater::eg_, NnetDiscriminativeUpdater::forward_data_, Nnet::GetComponent(), NnetDiscriminativeUpdater::GetInputFeatures(), AmNnet::GetNnet(), NnetDiscriminativeUpdater::nnet_to_update_, MatrixBase< Real >::NumCols(), Nnet::NumComponents(), MatrixBase< Real >::NumRows(), Component::Propagate(), and DiscriminativeNnetExample::spk_info.

Referenced by NnetDiscriminativeUpdater::Update().

void NnetDiscriminativeUpdater::Propagate() {
  const Nnet &nnet = am_nnet_.GetNnet();
  forward_data_.resize(nnet.NumComponents() + 1);

  SubMatrix<BaseFloat> input_feats = GetInputFeatures();
  int32 spk_dim = eg_.spk_info.Dim();
  if (spk_dim == 0) {
    forward_data_[0] = input_feats;
  } else {
    forward_data_[0].Resize(input_feats.NumRows(),
                            input_feats.NumCols() + eg_.spk_info.Dim());
    forward_data_[0].Range(0, input_feats.NumRows(),
                           0, input_feats.NumCols()).CopyFromMat(input_feats);
    forward_data_[0].Range(0, input_feats.NumRows(),
                           input_feats.NumCols(), spk_dim).CopyRowsFromVec(
                               eg_.spk_info);
  }

  for (int32 c = 0; c < nnet.NumComponents(); c++) {
    const Component &component = nnet.GetComponent(c);
    CuMatrix<BaseFloat> &input = forward_data_[c],
        &output = forward_data_[c+1];
    component.Propagate(chunk_info_out_[c], chunk_info_out_[c+1], input, &output);
    const Component *prev_component = (c == 0 ? NULL :
                                       &(nnet.GetComponent(c-1)));
    bool will_do_backprop = (nnet_to_update_ != NULL),
        keep_last_output = will_do_backprop &&
        ((c>0 && prev_component->BackpropNeedsOutput()) ||
         component.BackpropNeedsInput());
    if (!keep_last_output)
      forward_data_[c].Resize(0, 0); // We won't need this data; save memory.
  }
}
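When the example carries speaker information (eg_.spk_info is non-empty), Propagate() builds the first layer's input by copying the features into the left-hand columns and repeating the same spk_info vector in the extra right-hand columns of every row. A plain-C++ sketch of that layout, using a hypothetical row-major Matrix struct rather than Kaldi's CuMatrix:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical row-major matrix; Kaldi's CuMatrix is not used here.
struct Matrix {
  int rows;
  int cols;
  std::vector<float> data;
};

// Builds [features | spk_info] with spk_info repeated on every row, the
// layout Propagate() gives forward_data_[0] when spk_info is non-empty.
// With an empty spk_info the result is a plain copy of the features.
Matrix AppendSpeakerInfo(const Matrix &feats, const std::vector<float> &spk_info) {
  assert(feats.data.size() == static_cast<std::size_t>(feats.rows) * feats.cols);
  const int spk_dim = static_cast<int>(spk_info.size());
  Matrix out;
  out.rows = feats.rows;
  out.cols = feats.cols + spk_dim;
  out.data.resize(static_cast<std::size_t>(out.rows) * out.cols);
  for (int r = 0; r < feats.rows; ++r) {
    for (int c = 0; c < feats.cols; ++c)  // left block: this frame's features
      out.data[r * out.cols + c] = feats.data[r * feats.cols + c];
    for (int d = 0; d < spk_dim; ++d)  // right block: the same speaker vector
      out.data[r * out.cols + feats.cols + d] = spk_info[d];
  }
  return out;
}
```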
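Propagate() frees forward_data_[c] as soon as component c has run, unless backprop will need it. Because forward_data_[c] is both the output of component c-1 and the input of component c, it is kept when the previous component's backprop needs its output or the current component's backprop needs its input, and nothing is kept when nnet_to_update_ is NULL. The sketch below evaluates that rule from per-component flags, which are hypothetical stand-ins for BackpropNeedsInput() and BackpropNeedsOutput().

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for Component::BackpropNeedsInput() and
// Component::BackpropNeedsOutput().
struct BackpropNeeds {
  bool input;
  bool output;
};

// keep[c] says whether forward_data_[c] (input of component c, output of
// component c-1) survives the forward pass under the rule in Propagate().
// forward_data_[n], the network output, is never freed by that loop and is
// not represented here.
std::vector<bool> BuffersKept(const std::vector<BackpropNeeds> &comps,
                              bool will_do_backprop) {
  std::vector<bool> keep(comps.size(), false);
  for (std::size_t c = 0; c < comps.size(); ++c) {
    bool prev_needs_output = (c > 0 && comps[c - 1].output);
    keep[c] = will_do_backprop && (prev_needs_output || comps[c].input);
  }
  return keep;
}
```

In the test, the middle layer stands for one whose derivative can be computed from its output alone (as for a sigmoid, whose derivative is s(1 - s)), so its input buffer forward_data_[1] is freed.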

◆ Update()
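The body of Update() is not reproduced on this page. The other member descriptions constrain the order of the stages: LatticeComputations() performs the steps between Propagate() and Backprop(), and Propagate() only keeps intermediate outputs when nnet_to_update_ is non-NULL. The hypothetical mock below records that order; skipping Backprop() when there is no network to update is an assumption inferred from that check in Propagate(), not something this page confirms.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical mock of the updater's three stages; each stage only records
// its name. The real stages operate on forward_data_, lat_ and backward_data_.
class MockUpdater {
 public:
  explicit MockUpdater(bool has_nnet_to_update)
      : has_nnet_to_update_(has_nnet_to_update) {}
  void Update() {
    Propagate();
    LatticeComputations();
    if (has_nnet_to_update_) Backprop();  // assumption: stats-only runs skip backprop
  }
  const std::vector<std::string> &calls() const { return calls_; }

 private:
  void Propagate() { calls_.push_back("Propagate"); }
  void LatticeComputations() { calls_.push_back("LatticeComputations"); }
  void Backprop() { calls_.push_back("Backprop"); }
  bool has_nnet_to_update_;
  std::vector<std::string> calls_;
};
```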

Member Data Documentation

◆ am_nnet_

◆ backward_data_

◆ chunk_info_out_

◆ eg_

◆ forward_data_

◆ lat_

◆ nnet_to_update_

◆ opts_

◆ silence_phones_

◆ stats_

◆ tmodel_


The documentation for this class was generated from the following file: