106 Nnet *nnet_to_update,
112 KALDI_ERR <<
"Bad value for --silence-phones option: " 125 num_frames_output - eg_left_context;
154 0, input_feats.
NumCols()).CopyFromMat(input_feats);
156 input_feats.
NumCols(), spk_dim).CopyRowsFromVec(
165 const Component *prev_component = (c == 0 ? NULL :
168 keep_last_output = will_do_backprop &&
171 if (!keep_last_output)
209 std::vector<Int32Pair> requested_indexes;
211 requested_indexes.reserve(num_frames + wiggle_room *
lat_.NumStates());
214 for (
int32 t = 0; t < num_frames; t++) {
217 requested_indexes.push_back(
MakePair(t, pdf_id));
221 std::vector<int32> state_times;
226 for (
StateId s = 0; s < num_states; s++) {
228 for (fst::ArcIterator<Lattice> aiter(
lat_, s); !aiter.Done(); aiter.Next()) {
229 const Arc &arc = aiter.Value();
230 if (arc.ilabel != 0) {
232 requested_indexes.push_back(
MakePair(t, pdf_id));
237 std::vector<BaseFloat> answers;
239 answers.resize(requested_indexes.size());
240 posteriors.
Lookup(cu_requested_indexes, &(answers[0]));
242 int32 num_floored = 0;
249 for (index = 0; index < answers.size(); index++) {
251 if (post < floor_val) {
255 int32 pdf_id = requested_indexes[index].second;
258 answers[index] = pseudo_loglike;
260 if (num_floored > 0) {
261 KALDI_WARN <<
"Floored " << num_floored <<
" probabilities from nnet.";
267 double tot_num_like = 0.0;
269 tot_num_like += answers[index];
274 for (
StateId s = 0; s < num_states; s++) {
275 for (fst::MutableArcIterator<Lattice> aiter(&
lat_, s);
276 !aiter.Done(); aiter.Next()) {
277 Arc arc = aiter.Value();
278 if (arc.ilabel != 0) {
279 arc.weight.SetValue2(-answers[index]);
286 final.SetValue2(0.0);
287 lat_.SetFinal(s,
final);
298 double tot_num_post = 0.0, tot_den_post = 0.0;
299 std::vector<MatrixElement<BaseFloat> > sv_labels;
300 sv_labels.reserve(answers.size());
301 for (
int32 t = 0; t < post.size(); t++) {
302 for (
int32 i = 0;
i < post[t].size();
i++) {
303 int32 pdf_id = post[t][
i].first;
305 if (weight > 0.0) { tot_num_post += weight; }
306 else { tot_den_post -= weight; }
308 sv_labels.push_back(elem);
319 backward_data_.CompObjfAndDeriv(sv_labels, output, &tot_objf, &tot_weight);
338 bool convert_to_pdfs =
true, cancel =
true;
359 component_to_update, &input_deriv);
369 Nnet *nnet_to_update,
372 nnet_to_update, stats);
377 tot_t += other.
tot_t;
385 KALDI_ASSERT(criterion ==
"mmi" || criterion ==
"smbr" ||
386 criterion ==
"mpfe");
388 double avg_post_per_frame = tot_num_count / tot_t_weighted;
389 KALDI_LOG <<
"Number of frames is " << tot_t
390 <<
" (weighted: " << tot_t_weighted
391 <<
"), average (num or den) posterior per frame is " 392 << avg_post_per_frame;
394 if (criterion ==
"mmi") {
395 double num_objf = tot_num_objf / tot_t_weighted,
396 den_objf = tot_den_objf / tot_t_weighted,
397 objf = num_objf - den_objf;
398 KALDI_LOG <<
"MMI objective function is " << num_objf <<
" - " 399 << den_objf <<
" = " << objf <<
" per frame, over " 400 << tot_t_weighted <<
" frames.";
401 }
else if (criterion ==
"mpfe") {
402 double objf = tot_den_objf / tot_t_weighted;
404 KALDI_LOG <<
"MPFE objective function is " << objf
405 <<
" per frame, over " << tot_t_weighted <<
" frames.";
407 double objf = tot_den_objf / tot_t_weighted;
409 KALDI_LOG <<
"SMBR objective function is " << objf
410 <<
" per frame, over " << tot_t_weighted <<
" frames.";
fst::StdArc::StateId StateId
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
const Component & GetComponent(int32 c) const
int32 LeftContext() const
Returns the left-context summed over all the Components...
fst::ArcTpl< LatticeWeight > LatticeArc
std::vector< int32 > silence_phones_
CuMatrix< BaseFloat > backward_data_
int32 LatticeStateTimes(const Lattice &lat, vector< int32 > *times)
This function iterates over the states of a topologically sorted lattice and counts the time instance...
const DiscriminativeNnetExample & eg_
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
std::vector< CuMatrix< BaseFloat > > forward_data_
bool SplitStringToIntegers(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< I > *out)
Split a string (e.g.
Abstract class, basic element of the network, it is a box with defined inputs, outputs, and transformation functions interface.
virtual bool BackpropNeedsInput() const
NnetDiscriminativeStats * stats_
static Int32Pair MakePair(int32 first, int32 second)
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
int32 TransitionIdToPdf(int32 trans_id) const
SubMatrix< BaseFloat > GetInputFeatures() const
std::vector< ChunkInfo > chunk_info_out_
void Add(const NnetDiscriminativeStats &other)
void Lookup(const std::vector< Int32Pair > &indexes, Real *output) const
int32 NumComponents() const
Returns number of components — think of this as similar to # of layers, but e.g.
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
BaseFloat LatticeForwardBackwardMmi(const TransitionModel &tmodel, const Lattice &lat, const std::vector< int32 > &num_ali, bool drop_frames, bool convert_to_pdf_ids, bool cancel, Posterior *post)
This function can be used to compute posteriors for MMI, with a positive contribution for the numerat...
NnetDiscriminativeUpdater(const AmNnet &am_nnet, const TransitionModel &tmodel, const NnetDiscriminativeUpdateOptions &opts, const DiscriminativeNnetExample &eg, Nnet *nnet_to_update, NnetDiscriminativeStats *stats)
const VectorBase< BaseFloat > & Priors() const
int32 RightContext() const
Returns the right-context summed over all the Components...
double GetDiscriminativePosteriors(Posterior *post)
Assuming the lattice already has the correct scores in it, this function does the MPE or MMI forward-...
void ConvertLattice(const ExpandedFst< ArcTpl< Weight > > &ifst, MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, Int > > > *ofst, bool invert)
Convert lattice from a normal FST to a CompactLattice FST.
std::string silence_phones_str
static const LatticeWeightTpl Zero()
Vector< BaseFloat > spk_info
spk_info contains any component of the features that varies slowly or not at all with time (and hence...
fst::VectorFst< LatticeArc > Lattice
virtual void Backprop(const ChunkInfo &in_info, const ChunkInfo &out_info, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, const CuMatrixBase< BaseFloat > &out_deriv, Component *to_update, CuMatrix< BaseFloat > *in_deriv) const =0
Perform backward pass propagation of the derivative, and also either update the model (if to_update =...
BaseFloat LatticeForwardBackwardMpeVariants(const TransitionModel &trans, const std::vector< int32 > &silence_phones, const Lattice &lat, const std::vector< int32 > &num_ali, std::string criterion, bool one_silence_class, Posterior *post)
This function implements either the MPFE (minimum phone frame error) or SMBR (state-level minimum bay...
CompactLattice den_lat
The denominator lattice.
MatrixIndexT Dim() const
Returns the dimension of the vector.
Matrix< BaseFloat > input_frames
The input data — typically with a number of frames [NumRows()] larger than labels.size(), because it includes features to the left and right as needed for the temporal context of the network.
std::vector< int32 > num_ali
The numerator alignment.
BaseFloat weight
The weight we assign to this example; this will typically be one, but we include it for the sake of g...
void ScalePosterior(BaseFloat scale, Posterior *post)
Scales the BaseFloat (weight) element in the posterior entries.
void Print(std::string criterion)
Matrix for CUDA computing.
void LatticeComputations()
Does the parts between Propagate() and Backprop(), that involve forward-backward over the lattice...
MatrixIndexT NumCols() const
This struct is used to store the information we need for discriminative training (MMI or MPE)...
virtual void Propagate(const ChunkInfo &in_info, const ChunkInfo &out_info, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) const =0
Perform forward pass propagation Input->Output.
#define KALDI_ASSERT(cond)
const NnetDiscriminativeUpdateOptions & opts_
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
bool LatticeBoost(const TransitionModel &trans, const std::vector< int32 > &alignment, const std::vector< int32 > &silence_phones, BaseFloat b, BaseFloat max_silence_error, Lattice *lat)
Boosts LM probabilities by b * [number of frame errors]; equivalently, adds -b*[number of frame error...
void Propagate()
The forward-through-the-layers part of the computation.
MatrixIndexT NumRows() const
Dimensions.
Provides a vector abstraction class.
void ConvertPosteriorToPdfs(const TransitionModel &tmodel, const Posterior &post_in, Posterior *post_out)
Converts a posterior over transition-ids to be a posterior over pdf-ids.
int32 left_context
The number of frames of left context in the features (we can work out the #frames of right context fr...
CuMatrixBase< BaseFloat > & GetOutput()
void ComputeChunkInfo(int32 input_chunk_size, int32 num_chunks, std::vector< ChunkInfo > *chunk_info_out) const
Uses the output of the Context() functions of the network, to compute a vector of size NumComponents(...
Sub-matrix representation.
virtual bool BackpropNeedsOutput() const
const Nnet & GetNnet() const
void NnetDiscriminativeUpdate(const AmNnet &am_nnet, const TransitionModel &tmodel, const NnetDiscriminativeUpdateOptions &opts, const DiscriminativeNnetExample &eg, Nnet *nnet_to_update, NnetDiscriminativeStats *stats)
Does the neural net computation, lattice forward-backward, and backprop, for either the MMI...
const TransitionModel & tmodel_