kws-functions.cc
Go to the documentation of this file.
1 // kws/kws-functions.cc
2 
3 // Copyright 2012 Johns Hopkins University (Author: Guoguo Chen)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include <algorithm>
21 
22 #include "lat/lattice-functions.h"
23 #include "kws/kws-functions.h"
26 
27 // note: this .cc file does not include everything declared in kws-functions.h;
28 // the remainder are defined in kws-functions2.cc (for compilation speed and
29 // to avoid generating too-large object files on cygwin).
30 
31 namespace kaldi {
32 
33 bool CompareInterval(const Interval &i1,
34  const Interval &i2) {
35  return (i1.Start() < i2.Start() ? true :
36  i1.Start() > i2.Start() ? false:
37  i1.End() < i2.End() ? true: false);
38 }
39 
41  const std::vector<int32> &state_times) {
42  using namespace fst;
44 
45  // Hashmap to store the cluster heads.
46  unordered_map<StateId, std::vector<Interval> > head;
47 
48  // Step 1: Iterate over the lattice to get the arcs
49  StateId max_id = 0;
50  for (StateIterator<CompactLattice> siter(*clat); !siter.Done();
51  siter.Next()) {
52  StateId state_id = siter.Value();
53  for (ArcIterator<CompactLattice> aiter(*clat, state_id); !aiter.Done();
54  aiter.Next()) {
55  CompactLatticeArc arc = aiter.Value();
56  if (state_id >= state_times.size() || arc.nextstate >= state_times.size())
57  return false;
58  if (state_id > max_id)
59  max_id = state_id;
60  if (arc.nextstate > max_id)
61  max_id = arc.nextstate;
62  head[arc.ilabel].push_back(Interval(state_times[state_id],
63  state_times[arc.nextstate]));
64  }
65  }
66  // Check if alignments and the states match
67  if (state_times.size() != max_id+1)
68  return false;
69 
70  // Step 2: Iterates over the hashmap to get the cluster heads.
71  // We sort all the words on their start-time, and the process for getting
72  // the cluster heads is to take the first one as a cluster head; then go
73  // till we find the next one that doesn't overlap in time with the current
74  // cluster head, and so on.
75  unordered_map<StateId, std::vector<Interval> >::iterator iter;
76  for (iter = head.begin(); iter != head.end(); ++iter) {
77  // For this ilabel, sort all the arcs on time, from first to last.
78  sort(iter->second.begin(), iter->second.end(), CompareInterval);
79  std::vector<Interval> tmp;
80  tmp.push_back(iter->second[0]);
81  for (int32 i = 1; i < iter->second.size(); i++) {
82  if (tmp.back().End() <= iter->second[i].Start())
83  tmp.push_back(iter->second[i]);
84  }
85  iter->second = tmp;
86  }
87 
88  // Step 3: Cluster arcs according to the maximum overlap: attach
89  // each arc to the cluster-head (as identified in Step 2) which
90  // has the most temporal overlap with the current arc.
91  for (StateIterator<CompactLattice> siter(*clat); !siter.Done();
92  siter.Next()) {
93  CompactLatticeArc::StateId state_id = siter.Value();
94  for (MutableArcIterator<CompactLattice> aiter(clat, state_id);
95  !aiter.Done(); aiter.Next()) {
96  CompactLatticeArc arc = aiter.Value();
97  // We don't cluster the epsilon arcs
98  if (arc.ilabel == 0)
99  continue;
100  // We cluster the non-epsilon arcs
101  Interval interval(state_times[state_id], state_times[arc.nextstate]);
102  int32 max_overlap = 0;
103  size_t olabel = 1;
104  for (int32 i = 0; i < head[arc.ilabel].size(); i++) {
105  int32 overlap = interval.Overlap(head[arc.ilabel][i]);
106  if (overlap > max_overlap) {
107  max_overlap = overlap;
108  olabel = i + 1; // need non-epsilon label.
109  }
110  }
111  arc.olabel = olabel;
112  aiter.SetValue(arc);
113  }
114  }
115 
116  return true;
117 }
118 
119 
  // Arc mapper for Map() that turns a CompactLatticeArc into a KwsProductArc.
  // Presumably this is CompactLatticeToKwsProductFstMapper, the mapper that
  // CreateFactorTransducer() passes to Map() (the class header line is not
  // visible in this listing).
121  public:
126 
128 
  // ilabel, olabel and nextstate are copied unchanged. A Zero() input weight
  // maps to ToWeight::Zero(). Otherwise the first component of the new weight
  // is the sum of the two LatticeWeight costs (Value1() + Value2()), and the
  // second component is StdXStdprimeWeight::One(), or Zero() if the
  // LatticeWeight part is Zero(). The alignment string of the
  // CompactLatticeWeight is not carried over.
129  ToArc operator()(const FromArc &arc) const {
130  return ToArc(arc.ilabel,
131  arc.olabel,
132  (arc.weight == FromWeight::Zero() ?
133  ToWeight::Zero() :
134  ToWeight(arc.weight.Weight().Value1()
135  +arc.weight.Weight().Value2(),
136  (arc.weight.Weight() == LatticeWeight::Zero() ?
137  StdXStdprimeWeight::Zero() :
138  StdXStdprimeWeight::One()))),
139  arc.nextstate);
140  }
141 
  // Final weights are mapped too; no superfinal state is created.
142  fst::MapFinalAction FinalAction() const {
143  return fst::MAP_NO_SUPERFINAL;
144  }
145 
  // Input and output symbol tables are copied to the result.
146  fst::MapSymbolsAction InputSymbolsAction() const {
147  return fst::MAP_COPY_SYMBOLS;
148  }
149 
150  fst::MapSymbolsAction OutputSymbolsAction() const {
151  return fst::MAP_COPY_SYMBOLS;
152  }
153 
  // Reports the input properties unchanged.
  // NOTE(review): the mapping merges costs and drops alignment strings, so
  // weight-related property bits (e.g. weighted/unweighted) may not be
  // preserved exactly; confirm that passing props through is safe here.
154  uint64 Properties(uint64 props) const {
155  return props;
156  }
157 };
158 
159 
// Builds the factor transducer of clat (for keyword-search indexing) into
// *factor_transducer:
//  1. compute the forward (alpha) and backward (beta) scores of clat, and
//     return false if either computation fails;
//  2. Map() clat to a KwsProductFst, dropping the alignments;
//  3. push the costs using beta, and turn alpha into the alpha of the pushed
//     FST;
//  4. add a new start state ss with an arc to each eligible original state s,
//     with weight (-alpha[s], (state_times[s], One())), and a new final state
//     fs reached from each eligible s by an arc with olabel utterance_id and
//     weight (0, (One(), state_times[s])). The old final states are made
//     non-final.
// Returns true on success.
161  const std::vector<int32> &state_times,
162  int32 utterance_id,
163  KwsProductFst *factor_transducer) {
164  using namespace fst;
166 
167  // We first compute the alphas and betas
168  bool success = false;
169  std::vector<double> alpha;
170  std::vector<double> beta;
171  success = ComputeCompactLatticeAlphas(clat, &alpha);
172  success = success && ComputeCompactLatticeBetas(clat, &beta);
173  if (!success)
174  return false;
175 
176  // Now we map the CompactLattice to VectorFst<KwsProductArc>. We drop the
177  // alignment information and only keep the negated log-probs
178  Map(clat, factor_transducer, CompactLatticeToKwsProductFstMapper());
179 
180  // Now do the weight pushing manually on the CompactLattice format. Note that
181  // the alphas and betas in Kaldi are stored as the log-probs, not the negated
182  // log-probs, so the equation for weight pushing is a little different from
183  // the original algorithm (pay attention to the sign). We push the weight to
184  // initial and remove the total weight, i.e., the sum of all the outgoing
185  // transitions and final weight at any state is equal to One() (push only the
186  // negated log-prob, not the alignments)
  // NOTE: the pushing below actually operates on the mapped *factor_transducer
  // (a KwsProductFst), not on clat. Each arc cost becomes
  // cost + beta[src] - beta[dest], and each final cost becomes
  // final + beta[s]. Only the Value1() component changes; Value2() is kept.
187  for (StateIterator<KwsProductFst>
188  siter(*factor_transducer); !siter.Done(); siter.Next()) {
189  KwsProductArc::StateId state_id = siter.Value();
190  for (MutableArcIterator<KwsProductFst>
191  aiter(factor_transducer, state_id); !aiter.Done(); aiter.Next()) {
192  KwsProductArc arc = aiter.Value();
193  BaseFloat w = arc.weight.Value1().Value();
194  w += beta[state_id] - beta[arc.nextstate];
195  KwsProductWeight weight(w, arc.weight.Value2());
196  arc.weight = weight;
197  aiter.SetValue(arc);
198  }
199  // Weight of final state
200  if (factor_transducer->Final(state_id) != KwsProductWeight::Zero()) {
201  BaseFloat w = factor_transducer->Final(state_id).Value1().Value();
202  w += beta[state_id];
203  KwsProductWeight weight(w, factor_transducer->Final(state_id).Value2());
204  factor_transducer->SetFinal(state_id, weight);
205  }
206  }
207 
208  // Modify the alphas and set betas to zero. After that, we get the alphas and
209  // betas for the pushed FST. Since I will not use beta anymore, here I don't
210  // set them to zero. This can be derived from the weight pushing formula.
  // In other words, the alpha of state s in the pushed FST is
  // alpha[s] + beta[s] - beta[0], and every pushed beta would be 0. The vector
  // beta itself is left unmodified because it is not read again.
  // NOTE(review): beta[0] is used as the total score, which assumes state 0 is
  // the start state of clat; confirm that callers guarantee this.
  // A pushed alpha is a log-prob and should not be (noticeably) positive; the
  // 0.1 slack presumably tolerates rounding error.
211  for (int32 s = 0; s < alpha.size(); s++) {
212  alpha[s] += beta[s] - beta[0];
213 
214  if (alpha[s] > 0.1) {
215  KALDI_WARN << "Positive alpha " << alpha[s];
216  }
217  }
218 
219  // to understand the next part, look at the comment in
220  // ../kwsbin/lattice-to-kws-index.cc just above the call to
221  // EnsureEpsilonProperty(). We use the bool has_epsilon_property mainly to
222  // handle the case when someone comments out that call. It should always be
223  // true in the normal case.
224  std::vector<char> state_properties;
225  ComputeStateInfo(*factor_transducer, &state_properties);
226  bool has_epsilon_property = true;
227  for (size_t i = 0; i < state_properties.size(); i++) {
228  char c = state_properties[i];
    // NOTE(review): in this listing the second operand of each "&&" below is
    // missing (source lines 230 and 233 are absent), so both conditions are
    // incomplete as shown; restore them from the upstream source.
229  if ((c & kStateHasEpsilonArcsEntering) != 0 &&
231  has_epsilon_property = false;
232  if ((c & kStateHasEpsilonArcsLeaving) != 0 &&
234  has_epsilon_property = false;
235  }
236  if (!has_epsilon_property) {
237  KALDI_WARN << "Epsilon property does not hold, reverting to old behavior.";
238  }
239 
240  // OK, after the above preparation, we finally come to the factor generation
241  // step.
  // ns is the state count before ss and fs are added, so the loop below
  // visits only the states that came from clat.
242  StateId ns = factor_transducer->NumStates();
243  StateId ss = factor_transducer->AddState();
244  StateId fs = factor_transducer->AddState();
245  factor_transducer->SetStart(ss);
246  factor_transducer->SetFinal(fs, KwsProductWeight::One());
247 
248  for (StateId s = 0; s < ns; s++) {
249  // Add arcs from initial state to current state
    // (only states with non-epsilon arcs leaving, unless the epsilon property
    // does not hold); the weight carries -alpha[s] and the start time of s.
250  if (!has_epsilon_property ||
251  (state_properties[s] & kStateHasNonEpsilonArcsLeaving))
252  factor_transducer->AddArc(ss, KwsProductArc(0, 0, KwsProductWeight(-alpha[s], StdXStdprimeWeight(state_times[s], ArcticWeight::One())), s));
253  // Add arcs from current state to final state
    // (only states with non-epsilon arcs entering, unless the epsilon property
    // does not hold); the olabel is utterance_id and the weight carries the
    // end time of s.
254  if (!has_epsilon_property ||
255  (state_properties[s] & kStateHasNonEpsilonArcsEntering))
256  factor_transducer->AddArc(s, KwsProductArc(0, utterance_id, KwsProductWeight(0, StdXStdprimeWeight(TropicalWeight::One(), state_times[s])), fs));
257  // The old final state is not final any more
258  if (factor_transducer->Final(s) != KwsProductWeight::Zero())
259  factor_transducer->SetFinal(s, KwsProductWeight::Zero());
260  }
261 
262  return true;
263 }
264 
265 void RemoveLongSilences(int32 max_silence_frames,
266  const std::vector<int32> &state_times,
267  KwsProductFst *factor_transducer) {
268  using namespace fst;
270 
271  StateId ns = factor_transducer->NumStates();
272  StateId ss = factor_transducer->Start();
273  StateId bad_state = factor_transducer->AddState();
274  for (StateId s = 0; s < ns; s++) {
275  // Skip arcs start from the initial state
276  if (s == ss)
277  continue;
278  for (MutableArcIterator<KwsProductFst>
279  aiter(factor_transducer, s); !aiter.Done(); aiter.Next()) {
280  KwsProductArc arc = aiter.Value();
281  // Skip arcs end with the final state
282  if (factor_transducer->Final(arc.nextstate) != KwsProductWeight::Zero())
283  continue;
284  // Non-silence arcs
285  if (arc.ilabel != 0)
286  continue;
287  // Short silence arcs
288  if (state_times[arc.nextstate]-state_times[s] <= max_silence_frames)
289  continue;
290  // The rest are the long silence arcs, we point their nextstate to
291  // bad_state
292  arc.nextstate = bad_state;
293  aiter.SetValue(arc);
294  }
295  }
296 
297  // Trim the unsuccessful paths
298  Connect(factor_transducer);
299 }
300 
301 
302 template<class Arc>
303 static void DifferenceWrapper(const fst::VectorFst<Arc> &fst1,
304  const fst::VectorFst<Arc> &fst2,
305  fst::VectorFst<Arc> *difference) {
306  using namespace fst;
307  if (!fst2.Properties(kAcceptor, true)) {
308  // make it an acceptor by encoding the weights.
309  EncodeMapper<Arc> encoder(kEncodeLabels, ENCODE);
310  VectorFst<Arc> fst1_copy(fst1);
311  VectorFst<Arc> fst2_copy(fst2);
312  Encode(&fst1_copy, &encoder);
313  Encode(&fst2_copy, &encoder);
314  DifferenceWrapper(fst1_copy, fst2_copy, difference);
315  Decode(difference, encoder);
316  } else {
317  VectorFst<Arc> fst2_copy(fst2);
318  RmEpsilon(&fst2_copy); // or Difference will crash.
319  RemoveWeights(&fst2_copy); // or Difference will crash.
320  Difference(fst1, fst2_copy, difference);
321  }
322 }
323 
324 
325 void MaybeDoSanityCheck(const KwsLexicographicFst &index_transducer) {
327  if (GetVerboseLevel() < 2) return;
328  KwsLexicographicFst temp_transducer;
329  ShortestPath(index_transducer, &temp_transducer);
330  std::vector<Label> isymbols, osymbols;
331  KwsLexicographicWeight weight;
332  GetLinearSymbolSequence(temp_transducer, &isymbols, &osymbols, &weight);
333  std::ostringstream os;
334  for (size_t i = 0; i < isymbols.size(); i++)
335  os << isymbols[i] << ' ';
336  BaseFloat best_cost = weight.Value1().Value();
337  KALDI_VLOG(3) << "Best path: " << isymbols.size() << " isymbols " << ", "
338  << osymbols.size() << " osymbols, isymbols are " << os.str()
339  << ", best cost is " << best_cost;
340 
341  // Now get second-best path. This will exclude the best path, which
342  // will generally correspond to the empty word sequence (there will
343  // be isymbols and osymbols anyway though, because of the utterance-id
344  // having been encoded as an osymbol (and later, the EncodeFst turning it
345  // into a transducer).
346  KwsLexicographicFst difference_transducer;
347  DifferenceWrapper(index_transducer, temp_transducer, &difference_transducer);
348  ShortestPath(difference_transducer, &temp_transducer);
349 
350  GetLinearSymbolSequence(temp_transducer, &isymbols, &osymbols, &weight);
351  std::ostringstream os2;
352  for (size_t i = 0; i < isymbols.size(); i++)
353  os2 << isymbols[i] << ' ';
354  BaseFloat second_best_cost = weight.Value1().Value();
355  KALDI_VLOG(3) << "Second-best path: " << isymbols.size()
356  << " isymbols " << ", "
357  << osymbols.size() << " osymbols, isymbols are " << os2.str()
358  << ", second-best cost is " << second_best_cost;
359  if (second_best_cost < -0.01) {
360  KALDI_WARN << "Negative second-best cost found " << second_best_cost;
361  }
362 }
363 
364 
// Debugging check for a KwsProductFst. Below verbose level 2 it returns
// immediately, so the conversion is skipped as well. Otherwise it converts
// product_transducer into a KwsLexicographicFst with Map() and runs the
// KwsLexicographicFst overload of MaybeDoSanityCheck() on the result.
365 void MaybeDoSanityCheck(const KwsProductFst &product_transducer) {
366  if (GetVerboseLevel() < 2) return;
367  KwsLexicographicFst index_transducer;
368 
369  Map(product_transducer,
370  &index_transducer,
// NOTE(review): the mapper argument and closing ")" of this Map() call are
// missing from this listing (source line 371 is absent); restore them from
// the upstream source.
372 
373  MaybeDoSanityCheck(index_transducer);
374 }
375 
376 } // end namespace kaldi
fst::StdArc::StateId StateId
fst::StdArc::Label Label
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
fst::VectorFst< KwsProductArc > KwsProductFst
Definition: kaldi-kws.h:49
bool CompareInterval(const Interval &i1, const Interval &i2)
ToArc operator()(const FromArc &arc) const
fst::MapFinalAction FinalAction() const
bool ClusterLattice(CompactLattice *clat, const std::vector< int32 > &state_times)
Lattice::StateId StateId
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
int32 GetVerboseLevel()
Get verbosity level, usually set via the command-line '--verbose=' switch.
Definition: kaldi-error.h:60
static void DifferenceWrapper(const fst::VectorFst< Arc > &fst1, const fst::VectorFst< Arc > &fst2, fst::VectorFst< Arc > *difference)
kaldi::int32 int32
LogXStdXStdprimeWeight KwsProductWeight
Definition: kaldi-kws.h:47
fst::MapSymbolsAction OutputSymbolsAction() const
bool GetLinearSymbolSequence(const Fst< Arc > &fst, std::vector< I > *isymbols_out, std::vector< I > *osymbols_out, typename Arc::Weight *tot_weight_out)
GetLinearSymbolSequence gets the symbol sequence from a linear FST.
void ComputeStateInfo(const VectorFst< Arc > &fst, std::vector< char > *epsilon_info)
This function will set epsilon_info to have size equal to the NumStates() of the FST, containing a logical-or of the enum values kStateHasEpsilonArcsEntering, kStateHasNonEpsilonArcsEntering, kStateHasEpsilonArcsLeaving, and kStateHasNonEpsilonArcsLeaving.
int32 Overlap(Interval interval)
Definition: kws-functions.h:40
fst::VectorFst< KwsLexicographicArc > KwsLexicographicFst
Definition: kaldi-kws.h:46
fst::ProductWeight< TropicalWeight, ArcticWeight > StdXStdprimeWeight
Definition: kaldi-kws.h:39
static const ArcticWeightTpl< T > One()
Definition: arctic-weight.h:48
int32 Start() const
Definition: kws-functions.h:44
bool ComputeCompactLatticeBetas(const CompactLattice &clat, vector< double > *beta)
static const LatticeWeightTpl Zero()
bool CreateFactorTransducer(const CompactLattice &clat, const std::vector< int32 > &state_times, int32 utterance_id, KwsProductFst *factor_transducer)
#define KALDI_WARN
Definition: kaldi-error.h:150
fst::StdArc::Label Label
fst::VectorFst< CompactLatticeArc > CompactLattice
Definition: kaldi-lattice.h:46
void MaybeDoSanityCheck(const KwsLexicographicFst &index_transducer)
void RemoveLongSilences(int32 max_silence_frames, const std::vector< int32 > &state_times, KwsProductFst *factor_transducer)
StdLStdLStdWeight KwsLexicographicWeight
Definition: kaldi-kws.h:44
void RemoveWeights(MutableFst< Arc > *ifst)
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
LogXStdXStdprimeArc KwsProductArc
Definition: kaldi-kws.h:48
fst::ArcTpl< CompactLatticeWeight > CompactLatticeArc
Definition: kaldi-lattice.h:42
bool ComputeCompactLatticeAlphas(const CompactLattice &clat, vector< double > *alpha)
int32 End() const
Definition: kws-functions.h:45
fst::MapSymbolsAction InputSymbolsAction() const