discriminative-supervision.cc
Go to the documentation of this file.
1 // nnet3/discriminative-supervision.cc
2 
3 // Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey)
4 // 2014-2015 Vimal Manohar
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
22 #include "lat/lattice-functions.h"
23 
24 namespace kaldi {
25 namespace discriminative {
26 
27 
29  const DiscriminativeSupervision &other):
30  weight(other.weight), num_sequences(other.num_sequences),
31  frames_per_sequence(other.frames_per_sequence),
32  num_ali(other.num_ali), den_lat(other.den_lat) { }
33 
35  std::swap(weight, other->weight);
38  std::swap(num_ali, other->num_ali);
39  std::swap(den_lat, other->den_lat);
40 }
41 
43  const DiscriminativeSupervision &other) const {
44  return ( weight == other.weight &&
45  num_sequences == other.num_sequences &&
47  num_ali == other.num_ali &&
48  fst::Equal(den_lat, other.den_lat) );
49 }
50 
51 void DiscriminativeSupervision::Write(std::ostream &os, bool binary) const {
52  WriteToken(os, binary, "<DiscriminativeSupervision>");
53  WriteToken(os, binary, "<Weight>");
54  WriteBasicType(os, binary, weight);
55  WriteToken(os, binary, "<NumSequences>");
56  WriteBasicType(os, binary, num_sequences);
57  WriteToken(os, binary, "<FramesPerSeq>");
60  num_sequences > 0);
61 
62  WriteToken(os, binary, "<NumAli>");
63  WriteIntegerVector(os, binary, num_ali);
64 
65  WriteToken(os, binary, "<DenLat>");
66  if (!WriteLattice(os, binary, den_lat)) {
67  // We can't return error status from this function so we
68  // throw an exception.
69  KALDI_ERR << "Error writing denominator lattice to stream";
70  }
71 
72  WriteToken(os, binary, "</DiscriminativeSupervision>");
73 }
74 
75 void DiscriminativeSupervision::Read(std::istream &is, bool binary) {
76  ExpectToken(is, binary, "<DiscriminativeSupervision>");
77  ExpectToken(is, binary, "<Weight>");
78  ReadBasicType(is, binary, &weight);
79  ExpectToken(is, binary, "<NumSequences>");
80  ReadBasicType(is, binary, &num_sequences);
81  ExpectToken(is, binary, "<FramesPerSeq>");
82  ReadBasicType(is, binary, &frames_per_sequence);
84  num_sequences > 0);
85 
86  ExpectToken(is, binary, "<NumAli>");
87  ReadIntegerVector(is, binary, &num_ali);
88 
89  ExpectToken(is, binary, "<DenLat>");
90  {
91  Lattice *lat = NULL;
92  if (!ReadLattice(is, binary, &lat) || lat == NULL) {
93  // We can't return error status from this function so we
94  // throw an exception.
95  KALDI_ERR << "Error reading Lattice from stream";
96  }
97  den_lat = *lat;
98  delete lat;
99  TopSort(&den_lat);
100  }
101 
102  ExpectToken(is, binary, "</DiscriminativeSupervision>");
103 }
104 
105 bool DiscriminativeSupervision::Initialize(const std::vector<int32> &num_ali,
106  const Lattice &den_lat,
107  BaseFloat weight) {
108  if (num_ali.size() == 0) return false;
109  if (den_lat.NumStates() == 0) return false;
110 
111  this->weight = weight;
112  this->num_sequences = 1;
113  this->frames_per_sequence = num_ali.size();
114  this->num_ali = num_ali;
115  this->den_lat = den_lat;
116  KALDI_ASSERT(TopSort(&(this->den_lat)));
117 
118  // Checks if num frames in alignment matches lattice
119  Check();
120 
121  return true;
122 }
123 
125  int32 num_frames_subsampled = num_ali.size();
126  KALDI_ASSERT(num_frames_subsampled ==
128 
129  {
130  std::vector<int32> state_times;
131  int32 max_time = LatticeStateTimes(den_lat, &state_times);
132  KALDI_ASSERT(max_time == num_frames_subsampled);
133  }
134 }
135 
138  const TransitionModel &tmodel,
139  const DiscriminativeSupervision &supervision):
140  config_(config), tmodel_(tmodel), supervision_(supervision) {
141  if (supervision_.num_sequences != 1) {
142  KALDI_WARN << "Splitting already-reattached sequence (only expected in "
143  << "testing code)";
144  }
145 
146  KALDI_ASSERT(supervision_.num_sequences == 1); // For now, don't allow splitting already merged examples
147 
150 
151  int32 num_states = den_lat_.NumStates(),
153  KALDI_ASSERT(num_states > 0);
154  int32 start_state = den_lat_.Start();
155  // Lattice should be top-sorted and connected, so start-state must be 0.
156  KALDI_ASSERT(start_state == 0 && "Expecting start-state to be 0");
157 
158  KALDI_ASSERT(num_states == den_lat_scores_.state_times.size());
159  KALDI_ASSERT(den_lat_scores_.state_times[start_state] == 0);
160  KALDI_ASSERT(den_lat_scores_.state_times.back() == num_frames);
161 }
162 
163 // Make sure that for any given pdf-id and any given frame, the den-lat has
164 // only one transition-id mapping to that pdf-id, on the same frame.
165 // It helps us to more completely minimize the lattice. Note: we
166 // can't do this if the criterion is MPFE, because in that case the
167 // objective function will be affected by the phone-identities being
168 // different even if the pdf-ids are the same.
170  const std::vector<int32> &state_times, Lattice *lat) const {
171  typedef Lattice::StateId StateId;
172  typedef Lattice::Arc Arc;
173 
174  int32 num_frames = state_times.back(); // TODO: Check if this is always true
175  StateId num_states = lat->NumStates();
176 
177  std::vector<std::map<int32, int32> > pdf_to_tid(num_frames);
178  for (StateId s = 0; s < num_states; s++) {
179  int32 t = state_times[s];
180  for (fst::MutableArcIterator<Lattice> aiter(lat, s);
181  !aiter.Done(); aiter.Next()) {
182  KALDI_ASSERT(t >= 0 && t < num_frames);
183  Arc arc = aiter.Value();
184  KALDI_ASSERT(arc.ilabel != 0 && arc.ilabel == arc.olabel);
185  int32 pdf = tmodel_.TransitionIdToPdf(arc.ilabel);
186  if (pdf_to_tid[t].count(pdf) != 0) {
187  arc.ilabel = arc.olabel = pdf_to_tid[t][pdf];
188  aiter.SetValue(arc);
189  } else {
190  pdf_to_tid[t][pdf] = arc.ilabel;
191  }
192  }
193  }
194 }
195 
197  // Check if all the vectors are of size num_states
198  KALDI_ASSERT(state_times.size() == alpha.size() &&
199  state_times.size() == beta.size());
200 
201  // Check that the states are ordered in increasing order of state_times.
202  // This must be true since the states are in breadth-first search order.
203  KALDI_ASSERT(IsSorted(state_times));
204 }
205 
206 void DiscriminativeSupervisionSplitter::GetFrameRange(int32 begin_frame, int32 num_frames, bool normalize,
207  DiscriminativeSupervision *out_supervision) const {
208  int32 end_frame = begin_frame + num_frames;
209  // Note: end_frame is not included in the range of frames that the
210  // output supervision object covers; it's one past the end.
211  KALDI_ASSERT(num_frames > 0 && begin_frame >= 0 &&
212  begin_frame + num_frames <=
214 
217  begin_frame, end_frame, normalize,
218  &(out_supervision->den_lat));
219 
220  out_supervision->num_ali.clear();
221  std::copy(supervision_.num_ali.begin() + begin_frame,
222  supervision_.num_ali.begin() + end_frame,
223  std::back_inserter(out_supervision->num_ali));
224 
225  out_supervision->num_sequences = 1;
226  out_supervision->weight = supervision_.weight;
227  out_supervision->frames_per_sequence = num_frames;
228 
229  out_supervision->Check();
230 }
231 
233  const Lattice &in_lat, const LatticeInfo &scores,
234  int32 begin_frame, int32 end_frame, bool normalize,
235  Lattice *out_lat) const {
236  typedef Lattice::StateId StateId;
237 
238  const std::vector<int32> &state_times = scores.state_times;
239 
240  // Some checks to ensure the lattice and scores are prepared properly
241  KALDI_ASSERT(state_times.size() == in_lat.NumStates());
242  if (!in_lat.Properties(fst::kTopSorted, true))
243  KALDI_ERR << "Input lattice must be topologically sorted.";
244 
245  std::vector<int32>::const_iterator begin_iter =
246  std::lower_bound(state_times.begin(), state_times.end(), begin_frame),
247  end_iter = std::lower_bound(begin_iter,
248  state_times.end(), end_frame);
249 
250  KALDI_ASSERT(*begin_iter == begin_frame &&
251  (begin_iter == state_times.begin() ||
252  begin_iter[-1] < begin_frame));
253  // even if end_frame == supervision_.num_frames, there should be a state with
254  // that frame index.
255  KALDI_ASSERT(end_iter[-1] < end_frame &&
256  (end_iter < state_times.end() || *end_iter == end_frame));
257  StateId begin_state = begin_iter - state_times.begin(),
258  end_state = end_iter - state_times.begin();
259 
260  KALDI_ASSERT(end_state > begin_state);
261  out_lat->DeleteStates();
262  out_lat->ReserveStates(end_state - begin_state + 2);
263 
264  // Add special start state
265  StateId start_state = out_lat->AddState();
266  out_lat->SetStart(start_state);
267 
268  for (StateId i = begin_state; i < end_state; i++)
269  out_lat->AddState();
270 
271  // Add the special final-state.
272  StateId final_state = out_lat->AddState();
273  out_lat->SetFinal(final_state, LatticeWeight::One());
274 
275  for (StateId state = begin_state; state < end_state; state++) {
276  StateId output_state = state - begin_state + 1;
277  if (state_times[state] == begin_frame) {
278  // we'd like to make this an initial state, but OpenFst doesn't allow
279  // multiple initial states. Instead we add an epsilon transition to it
280  // from our actual initial state. The weight on this
281  // transition is the forward probability of the said 'initial state'
283  weight.SetValue1((normalize ? scores.beta[0] : 0.0) - scores.alpha[state]);
284  // Add negative of the forward log-probability to the graph cost score,
285  // since the acoustic scores would be changed later.
286  // Assuming that the lattice is scaled with appropriate acoustic
287  // scale.
288  // We additionally normalize using the total lattice score. Since the
289  // same score is added as normalizer to all the paths in the lattice,
290  // the relative probabilities of the paths in the lattice is not affected.
291  // Note: Doing a forward-backward on this split must result in a total
292  // score of 0 because of the normalization.
293 
294  out_lat->AddArc(start_state,
295  LatticeArc(0, 0, weight, output_state));
296  } else {
297  KALDI_ASSERT(scores.state_times[state] < end_frame);
298  }
299  for (fst::ArcIterator<Lattice> aiter(in_lat, state);
300  !aiter.Done(); aiter.Next()) {
301  const LatticeArc &arc = aiter.Value();
302  StateId nextstate = arc.nextstate;
303  if (nextstate >= end_state) {
304  // A transition to any state outside the range becomes a transition to
305  // our special final-state.
306  // The weight is just the negative of the backward log-probability +
307  // the arc cost. We again normalize with the total lattice score.
308  LatticeWeight weight;
309  //KALDI_ASSERT(scores.beta[state] < 0);
310  weight.SetValue1(arc.weight.Value1() - scores.beta[nextstate]);
311  weight.SetValue2(arc.weight.Value2());
312  // Add negative of the backward log-probability to the LM score, since
313  // the acoustic scores would be changed later.
314  // Note: We don't normalize here because that is already done with the
315  // initial cost.
316 
317  out_lat->AddArc(output_state,
318  LatticeArc(arc.ilabel, arc.olabel, weight, final_state));
319  } else {
320  StateId output_nextstate = nextstate - begin_state + 1;
321  out_lat->AddArc(output_state,
322  LatticeArc(arc.ilabel, arc.olabel, arc.weight, output_nextstate));
323  }
324  }
325  }
326 
327  // Get rid of the word labels and put the
328  // transition-ids on both sides.
329  fst::Project(out_lat, fst::PROJECT_INPUT);
330  fst::RmEpsilon(out_lat);
331 
333  CollapseTransitionIds(state_times, out_lat);
334 
335  if (config_.determinize) {
336  if (!config_.minimize) {
337  Lattice tmp_lat;
338  fst::Determinize(*out_lat, &tmp_lat);
339  std::swap(*out_lat, tmp_lat);
340  } else {
341  Lattice tmp_lat;
342  fst::Reverse(*out_lat, &tmp_lat);
343  fst::Determinize(tmp_lat, out_lat);
344  fst::Reverse(*out_lat, &tmp_lat);
345  fst::Determinize(tmp_lat, out_lat);
346  fst::RmEpsilon(out_lat);
347  }
348  }
349 
350  fst::TopSort(out_lat);
351  std::vector<int32> state_times_tmp;
352  KALDI_ASSERT(LatticeStateTimes(*out_lat, &state_times_tmp) ==
353  end_frame - begin_frame);
354 
355  // Remove the acoustic scale that was previously added
356  if (config_.acoustic_scale != 1.0) {
358  1 / config_.acoustic_scale), out_lat);
359  }
360 }
361 
363  Lattice *lat, LatticeInfo *scores) const {
364  // Scale the lattice to appropriate acoustic scale. It is important to
365  // ensure this is equal to the acoustic scale used while training. This is
366  // because, on splitting lattices, the initial and final costs are added
367  // into the graph cost.
369  if (config_.acoustic_scale != 1.0)
371  config_.acoustic_scale), lat);
372 
373  LatticeStateTimes(*lat, &(scores->state_times));
374  int32 num_states = lat->NumStates();
375  std::vector<std::pair<int32,int32> > state_time_indexes(num_states);
376  for (int32 s = 0; s < num_states; s++) {
377  state_time_indexes[s] = std::make_pair(scores->state_times[s], s);
378  }
379 
380  // Order the states based on the state times. This is stronger than just
381  // topological sort. This is required by the lattice splitting code.
382  std::sort(state_time_indexes.begin(), state_time_indexes.end());
383 
384  std::vector<int32> state_order(num_states);
385  for (int32 s = 0; s < num_states; s++) {
386  state_order[state_time_indexes[s].second] = s;
387  }
388 
389  fst::StateSort(lat, state_order);
390  ComputeLatticeScores(*lat, scores);
391 }
392 
394  LatticeInfo *scores) const {
395  LatticeStateTimes(lat, &(scores->state_times));
396  ComputeLatticeAlphasAndBetas(lat, false,
397  &(scores->alpha), &(scores->beta));
398  scores->Check();
399  // This check will fail if the lattice is not breadth-first search sorted
400 }
401 
402 void MergeSupervision(const std::vector<const DiscriminativeSupervision*> &input,
403  DiscriminativeSupervision *output_supervision) {
404  KALDI_ASSERT(!input.empty());
405  int32 num_inputs = input.size();
406  if (num_inputs == 1) {
407  *output_supervision = *(input[0]);
408  return;
409  }
410  *output_supervision = *(input[num_inputs-1]);
411  for (int32 i = num_inputs - 2; i >= 0; i--) {
412  const DiscriminativeSupervision &src = *(input[i]);
413  KALDI_ASSERT(src.num_sequences == 1);
414  if (output_supervision->weight == src.weight &&
415  output_supervision->frames_per_sequence ==
416  src.frames_per_sequence) {
417  // Combine with current output
418  // append src.den_lat to output_supervision->den_lat.
419  fst::Concat(src.den_lat, &output_supervision->den_lat);
420 
421  output_supervision->num_ali.insert(
422  output_supervision->num_ali.begin(),
423  src.num_ali.begin(), src.num_ali.end());
424 
425  output_supervision->num_sequences++;
426  } else {
427  KALDI_ERR << "Mismatch weight or frames_per_sequence between inputs";
428  }
429  }
430  DiscriminativeSupervision &out_sup = *output_supervision;
431  fst::TopSort(&(out_sup.den_lat));
432  out_sup.Check();
433 }
434 
435 } // namespace discriminative
436 } // namespace kaldi
fst::StdArc::StateId StateId
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void GetFrameRange(int32 begin_frame, int32 frames_per_sequence, bool normalize, DiscriminativeSupervision *supervision) const
bool ReadLattice(std::istream &is, bool binary, Lattice **lat)
int32 LatticeStateTimes(const Lattice &lat, vector< int32 > *times)
This function iterates over the states of a topologically sorted lattice and counts the time instance...
void ComputeLatticeScores(const Lattice &lat, LatticeInfo *scores) const
double ComputeLatticeAlphasAndBetas(const LatticeType &lat, bool viterbi, vector< double > *alpha, vector< double > *beta)
static const LatticeWeightTpl One()
void CreateRangeLattice(const Lattice &in_lat, const LatticeInfo &scores, int32 begin_frame, int32 end_frame, bool normalize, Lattice *out_lat) const
const SplitDiscriminativeSupervisionOptions & config_
void Write(std::ostream &os, bool binary) const
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
Lattice::StateId StateId
void CollapseTransitionIds(const std::vector< int32 > &state_times, Lattice *lat) const
void swap(basic_filebuf< CharT, Traits > &x, basic_filebuf< CharT, Traits > &y)
kaldi::int32 int32
void MergeSupervision(const std::vector< const DiscriminativeSupervision *> &input, DiscriminativeSupervision *output_supervision)
This function appends a list of supervision objects to create what will usually be a single such obje...
int32 TransitionIdToPdf(int32 trans_id) const
bool operator==(const DiscriminativeSupervision &other) const
std::vector< std::vector< double > > AcousticLatticeScale(double acwt)
const size_t count
void PrepareLattice(Lattice *lat, LatticeInfo *scores) const
void ReadIntegerVector(std::istream &is, bool binary, std::vector< T > *v)
Function for reading STL vector of integer types.
Definition: io-funcs-inl.h:232
void ScaleLattice(const std::vector< std::vector< ScaleFloat > > &scale, MutableFst< ArcTpl< Weight > > *fst)
Scales the pairs of weights in LatticeWeight or CompactLatticeWeight by viewing the pair (a...
void ExpectToken(std::istream &is, bool binary, const char *token)
ExpectToken tries to read in the given token, and throws an exception on failure. ...
Definition: io-funcs.cc:191
fst::VectorFst< LatticeArc > Lattice
Definition: kaldi-lattice.h:44
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
bool WriteLattice(std::ostream &os, bool binary, const Lattice &t)
bool IsSorted(const std::vector< T > &vec)
Returns true if the vector is sorted.
Definition: stl-utils.h:47
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
bool Initialize(const std::vector< int32 > &alignment, const Lattice &lat, BaseFloat weight)
void WriteIntegerVector(std::ostream &os, bool binary, const std::vector< T > &v)
Function for writing STL vectors of integer types.
Definition: io-funcs-inl.h:198
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34
DiscriminativeSupervisionSplitter(const SplitDiscriminativeSupervisionOptions &config, const TransitionModel &tmodel, const DiscriminativeSupervision &supervision)