discriminative-supervision.h
Go to the documentation of this file.
1 // nnet3/discriminative-supervision.h
2 
3 // Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey)
4 // 2014-2015 Vimal Manohar
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #ifndef KALDI_NNET3_DISCRIMINATIVE_SUPERVISION_H
22 #define KALDI_NNET3_DISCRIMINATIVE_SUPERVISION_H
23 
24 #include "util/table-types.h"
25 #include "hmm/posterior.h"
26 #include "hmm/transition-model.h"
27 #include "lat/kaldi-lattice.h"
28 
29 namespace kaldi {
30 namespace discriminative {
31 
32 
39  bool minimize; // we'll push and minimize if this is true.
41 
43  remove_output_symbols(true), collapse_transition_ids(true),
44  remove_epsilons(true), determinize(true),
45  minimize(true), acoustic_scale(0.1) { }
46 
// Registers every splitting option with 'opts'. Each call passes the
// option's name, a pointer to the member that holds its value, and the
// documentation string shown for that option. The defaults for these
// members are set in the constructor's initializer list.
47  void Register(OptionsItf *opts) {
48  opts->Register("collapse-transition-ids", &collapse_transition_ids,
49  "If true, modify the transition-ids on denominator lattice "
50  "so that on each frame, there is just one with any given "
51  "pdf-id. This allows us to determinize and minimize "
52  "more completely.");
53  opts->Register("remove-output-symbols", &remove_output_symbols,
54  "Remove output symbols from lattice to convert it to an "
55  "acceptor and make it more determinizable");
56  opts->Register("remove-epsilons", &remove_epsilons,
57  "Remove epsilons from the split lattices");
58  opts->Register("determinize", &determinize, "If true, we determinize "
59  "lattices (as Lattice) after splitting and possibly minimize");
60  opts->Register("minimize", &minimize, "If true, we push and "
61  "minimize lattices (as Lattice) after splitting");
// NOTE(review): the help text says this must match the value used when the
// supervision was created; nothing visible here enforces that.
62  opts->Register("acoustic-scale", &acoustic_scale,
63  "Scaling factor for acoustic likelihoods (should match the "
64  "value used in discriminative-get-supervision)");
65  }
66 };
67 
68 /*
69  This file contains some declarations relating to the object we use to
70  encode the supervision information for sequence training
71 */
72 
73 // struct DiscriminativeSupervision is the fully-processed information for
74 // a whole utterance or (after splitting) part of an utterance.
76  // The weight we assign to this example;
77  // this will typically be one, but we include it
78  // for the sake of generality.
80 
81  // num_sequences will be 1 if you create a DiscriminativeSupervision object from a single
82  // lattice or alignment, but if you combine multiple DiscriminativeSupervision objects
83  // the 'num_sequences' is the number of objects that were combined (the
84  // lattices get appended).
86 
87  // the number of frames in each sequence of appended objects. frames_per_sequence *
88  // num_sequences must equal the path length of any path in the lattices.
89  // Technically this information is redundant with the lattices, but it's convenient
90  // to have it separately.
92 
93  // The numerator alignment
94  // Usually obtained by aligning the reference text with the seed neural
95  // network model; can be the best path of generated lattice in the case of
96  // semi-supervised training.
97  std::vector<int32> num_ali;
98 
99  // Note: any acoustic
100  // likelihoods in the lattices will be
101  // recomputed at the time we train.
102 
103  // The denominator lattice.
105 
// Default constructor: weight of 1.0, a single sequence, and
// frames_per_sequence set to -1.
// NOTE(review): -1 presumably marks "not yet initialized" -- confirm
// against Initialize() and Check().
106  DiscriminativeSupervision(): weight(1.0), num_sequences(1),
107  frames_per_sequence(-1) { }
108 
110 
111 
112  // This function creates a supervision object from numerator alignment
113  // and denominator lattice. The supervision object is used for sequence
114  // discriminative training.
115  // Topologically sorts the lattice after copying to the supervision object.
116  // Returns false when alignment or lattice is empty
117  bool Initialize(const std::vector<int32> &alignment,
118  const Lattice &lat,
119  BaseFloat weight);
120 
121  void Swap(DiscriminativeSupervision *other);
122 
123  bool operator == (const DiscriminativeSupervision &other) const;
124 
125  // This function checks that this supervision object satisfies some
126  // of the properties we expect of it, and calls KALDI_ERR if not.
127  void Check() const;
128 
// Returns the total number of frames covered by this object: the number
// of sequences multiplied by the number of frames in each sequence.
129  inline int32 NumFrames() const {
130  return num_sequences * frames_per_sequence;
131  }
132 
133  void Write(std::ostream &os, bool binary) const;
134  void Read(std::istream &is, bool binary);
135 };
136 
137 // This class is used for splitting something of type
138 // DiscriminativeSupervision into
139 // multiple pieces corresponding to different frame-ranges.
141  public:
142  typedef fst::ArcTpl<LatticeWeight> LatticeArc;
143  typedef fst::VectorFst<LatticeArc> Lattice;
144 
147  const TransitionModel &tmodel,
148  const DiscriminativeSupervision &supervision);
149 
150  // A structure used to store the forward and backward scores
151  // and state times of a lattice
152  struct LatticeInfo {
153  // These values are stored in log.
// NOTE(review): alpha is presumably the forward score of each lattice
// state -- confirm against ComputeLatticeScores().
154  std::vector<double> alpha;
// NOTE(review): beta is presumably the backward score of each lattice
// state -- confirm against ComputeLatticeScores().
155  std::vector<double> beta;
// State times of the lattice.
// NOTE(review): presumably one frame index per state -- confirm.
156  std::vector<int32> state_times;
157 
// Consistency check on the stored vectors; the checks themselves are
// defined outside this header.
158  void Check() const;
159  };
160 
161  // Extracts a frame range of the supervision into 'supervision'.
162  void GetFrameRange(int32 begin_frame, int32 frames_per_sequence,
163  bool normalize,
164  DiscriminativeSupervision *supervision) const;
165 
166  // Get the acoustic scaled denominator lattice out for debugging purposes
167  inline const Lattice& DenLat() const { return den_lat_; }
168 
169  private:
170 
171  // Creates an output lattice covering frames begin_frame <= t < end_frame,
172  // assuming that the corresponding state-range that we need to
173  // include, begin_state <= s < end_state has been included.
174  // (note: the output lattice will also have two special initial and final
175  // states).
176  // Also does post-processing (RmEpsilon, Determinize,
177  // TopSort on the result). See code for details.
178  void CreateRangeLattice(const Lattice &in_lat,
179  const LatticeInfo &scores,
180  int32 begin_frame, int32 end_frame, bool normalize,
181  Lattice *out_lat) const;
182 
183  // Config options for splitting supervision object
185 
186  // Transition model is used by the function
187  // CollapseTransitionIds()
189 
190  // A reference to the supervision object that we will be splitting
192 
193  // LatticeInfo object for denominator lattice.
194  // This will be computed when PrepareLattice function is called.
196 
197  // Copy of denominator lattice. This is required because the lattice states
198  // need to be ordered in breadth-first search order.
199  Lattice den_lat_;
200 
201  // Function to compute lattice scores for a lattice
202  void ComputeLatticeScores(const Lattice &lat, LatticeInfo *scores) const;
203 
204  // Prepare lattice :
205  // 1) Order states in breadth-first search order
206  // 2) Compute state times, which must form a non-decreasing vector
207  // 3) Compute lattice alpha and beta scores
208  void PrepareLattice(Lattice *lat, LatticeInfo *scores) const;
209 
210  // Modifies the transition-ids on 'lat' so that on each frame, there is just
211  // one with any given pdf-id. This allows us to determinize and minimize
212  // more completely.
213  void CollapseTransitionIds(const std::vector<int32> &state_times,
214  Lattice *lat) const;
215 
216 };
217 
224 
// This function appends a list of supervision objects ('input') to create
// what will usually be a single such object.
// NOTE(review): the result is presumably written to 'output_supervision';
// confirm against the definition.
225 void MergeSupervision(const std::vector<const DiscriminativeSupervision*> &input,
226  DiscriminativeSupervision *output_supervision);
227 
228 
229 } // namespace discriminative
230 } // namespace kaldi
231 
232 #endif // KALDI_NNET3_DISCRIMINATIVE_SUPERVISION_H
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
const SplitDiscriminativeSupervisionOptions & config_
kaldi::int32 int32
void MergeSupervision(const std::vector< const DiscriminativeSupervision *> &input, DiscriminativeSupervision *output_supervision)
This function appends a list of supervision objects to create what will usually be a single such obje...
virtual void Register(const std::string &name, bool *ptr, const std::string &doc)=0
fst::VectorFst< LatticeArc > Lattice
Definition: kaldi-lattice.h:44
bool operator==(const LatticeWeightTpl< FloatType > &wa, const LatticeWeightTpl< FloatType > &wb)