nnet-compute-discriminative.h
Go to the documentation of this file.
1 // nnet2/nnet-compute-discriminative.h
2 
3 // Copyright 2012-2013 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #ifndef KALDI_NNET2_NNET_COMPUTE_DISCRIMINATIVE_H_
21 #define KALDI_NNET2_NNET_COMPUTE_DISCRIMINATIVE_H_
22 
23 #include "nnet2/am-nnet.h"
24 #include "nnet2/nnet-example.h"
25 #include "hmm/transition-model.h"
26 
27 namespace kaldi {
28 namespace nnet2 {
29 
30 /* This header provides functionality for doing model updates, and computing
31  gradients, using discriminative objective functions (MPFE, SMBR, MMI).
32  We use the DiscriminativeNnetExample defined in nnet-example.h.
33 */
34 
36  std::string criterion; // "mmi" or "mpfe" or "smbr"
38  bool drop_frames; // for MMI, true if we ignore frames where alignment
39  // pdf-id is not in the lattice.
40  bool one_silence_class; // Affects MPE/SMBR>
41  BaseFloat boost; // for MMI, boosting factor (would be Boosted MMI)... e.g. 0.1.
42 
43  std::string silence_phones_str; // colon-separated list of integer ids of silence phones,
44  // for MPE/SMBR only.
45 
46  NnetDiscriminativeUpdateOptions(): criterion("smbr"), acoustic_scale(0.1),
47  drop_frames(false),
48  one_silence_class(false),
49  boost(0.0) { }
50 
51  void Register(OptionsItf *opts) {
52  opts->Register("criterion", &criterion, "Criterion, 'mmi'|'mpfe'|'smbr', "
53  "determines the objective function to use. Should match "
54  "option used when we created the examples.");
55  opts->Register("acoustic-scale", &acoustic_scale, "Weighting factor to "
56  "apply to acoustic likelihoods.");
57  opts->Register("drop-frames", &drop_frames, "For MMI, if true we drop frames "
58  "with no overlap of num and den frames");
59  opts->Register("boost", &boost, "Boosting factor for boosted MMI (e.g. 0.1)");
60  opts->Register("one-silence-class", &one_silence_class, "If true, newer "
61  "behavior which will tend to reduce insertions.");
62  opts->Register("silence-phones", &silence_phones_str,
63  "For MPFE or SMBR, colon-separated list of integer ids of "
64  "silence phones, e.g. 1:2:3");
65 
66  }
67 };
68 
69 
/// Accumulated statistics for discriminative training; all counters start
/// at zero.
struct NnetDiscriminativeStats {
  double tot_t;           // total number of frames
  double tot_t_weighted;  // total number of frames times weight.
  double tot_num_count;   // total count of numerator posterior (should be
                          // identical to denominator-posterior count, so we
                          // don't separately compute that).
  double tot_num_objf;    // for MMI, the (weighted) numerator likelihood; for
                          // SMBR/MPFE, 0.
  double tot_den_objf;    // for MMI, the (weighted) denominator likelihood; for
                          // SMBR/MPFE, the objective function.

  /// Zero-initializes every accumulator.  An explicit initializer list is
  /// used rather than memset(this, ...) so the header does not depend on
  /// <cstring> and the struct stays correct if non-trivial members are added.
  NnetDiscriminativeStats(): tot_t(0.0), tot_t_weighted(0.0),
                             tot_num_count(0.0), tot_num_objf(0.0),
                             tot_den_objf(0.0) { }

  /// Defined outside this header.  `criterion` is presumably one of
  /// "mmi", "mpfe" or "smbr" -- confirm against the implementation.
  void Print(std::string criterion);

  /// Defined outside this header; presumably adds the counters of `other`
  /// into this object -- confirm against the implementation.
  void Add(const NnetDiscriminativeStats &other);
};
84 
104 void NnetDiscriminativeUpdate(const AmNnet &am_nnet,
105  const TransitionModel &tmodel,
107  const DiscriminativeNnetExample &eg,
108  Nnet *nnet_to_update,
109  NnetDiscriminativeStats *stats);
110 
111 
112 } // namespace nnet2
113 } // namespace kaldi
114 
115 #endif // KALDI_NNET2_NNET_COMPUTE_DISCRIMINATIVE_H_
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
virtual void Register(const std::string &name, bool *ptr, const std::string &doc)=0
This struct is used to store the information we need for discriminative training (MMI or MPE)...
Definition: nnet-example.h:136
void Print(const Fst< Arc > &fst, std::string message)
void NnetDiscriminativeUpdate(const AmNnet &am_nnet, const TransitionModel &tmodel, const NnetDiscriminativeUpdateOptions &opts, const DiscriminativeNnetExample &eg, Nnet *nnet_to_update, NnetDiscriminativeStats *stats)
Does the neural net computation, lattice forward-backward, and backprop, for either the MMI...