table-matcher.h
Go to the documentation of this file.
1 // fstext/table-matcher.h
2 
3 // Copyright 2009-2011 Microsoft Corporation
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #ifndef KALDI_FSTEXT_TABLE_MATCHER_H_
21 #define KALDI_FSTEXT_TABLE_MATCHER_H_
22 #include <fst/fstlib.h>
23 #include <fst/fst-decl.h>
24 
25 
26 
27 namespace fst {
28 
29 
41 
43  float table_ratio; // we construct the table if it would be at least this full.
45  TableMatcherOptions(): table_ratio(0.25), min_table_size(4) { }
46 };
47 
48 
49 // Introducing an "impl" class for TableMatcher because
50 // we need to do a shallow copy of the Matcher for when
51 // we want to cache tables for multiple compositions.
52 template<class F, class BackoffMatcher = SortedMatcher<F> >
53 class TableMatcherImpl : public MatcherBase<typename F::Arc> {
54  public:
55  typedef F FST;
56  typedef typename F::Arc Arc;
57  typedef typename Arc::Label Label;
58  typedef typename Arc::StateId StateId;
59  typedef StateId ArcId; // Use this type to store arc offsets [it's actually size_t
60  // in the Seek function of ArcIterator, but StateId should be big enough].
61  typedef typename Arc::Weight Weight;
62 
63 
64  public:
65 
66  TableMatcherImpl(const FST &fst, MatchType match_type,
68  match_type_(match_type),
69  fst_(fst.Copy()),
70  loop_(match_type == MATCH_INPUT ?
71  Arc(kNoLabel, 0, Weight::One(), kNoStateId) :
72  Arc(0, kNoLabel, Weight::One(), kNoStateId)),
73  aiter_(NULL),
74  s_(kNoStateId), opts_(opts),
75  backoff_matcher_(fst, match_type)
76  {
77  assert(opts_.min_table_size > 0);
78  if (match_type == MATCH_INPUT)
79  assert(fst_->Properties(kILabelSorted, true) == kILabelSorted);
80  else if (match_type == MATCH_OUTPUT)
81  assert(fst_->Properties(kOLabelSorted, true) == kOLabelSorted);
82  else
83  assert(0 && "Invalid FST properties");
84  }
85 
86  virtual const FST &GetFst() const { return *fst_; }
87 
88  virtual ~TableMatcherImpl() {
89  std::vector<ArcId> *const empty = ((std::vector<ArcId>*)(NULL)) + 1; // special marker.
90  for (size_t i = 0; i < tables_.size(); i++) {
91  if (tables_[i] != NULL && tables_[i] != empty)
92  delete tables_[i];
93  }
94  delete aiter_;
95  delete fst_;
96  }
97 
98  virtual MatchType Type(bool test) const {
99  return match_type_;
100  }
101 
102  void SetState(StateId s) {
103  if (aiter_) {
104  delete aiter_;
105  aiter_ = NULL;
106  }
107  if (match_type_ == MATCH_NONE)
108  LOG(FATAL) << "TableMatcher: bad match type";
109  s_ = s;
110  std::vector<ArcId> *const empty = ((std::vector<ArcId>*)(NULL)) + 1; // special marker.
111  if (static_cast<size_t>(s) >= tables_.size()) {
112  assert(s>=0);
113  tables_.resize(s+1, NULL);
114  }
115  std::vector<ArcId>* &this_table_ = tables_[s]; // note: ref to ptr.
116  if (this_table_ == empty) {
117  backoff_matcher_.SetState(s);
118  return;
119  } else if (this_table_ == NULL) { // NULL means has not been set.
120  ArcId num_arcs = fst_->NumArcs(s);
121  if (num_arcs == 0 || num_arcs < opts_.min_table_size) {
122  this_table_ = empty;
123  backoff_matcher_.SetState(s);
124  return;
125  }
126  ArcIterator<FST> aiter(*fst_, s);
127  aiter.SetFlags(kArcNoCache|(match_type_ == MATCH_OUTPUT?kArcOLabelValue:kArcILabelValue),
128  kArcNoCache|kArcValueFlags);
129  // the statement above, says: "Don't cache stuff; and I only need the ilabel/olabel
130  // to be computed.
131  aiter.Seek(num_arcs - 1);
132  Label highest_label = (match_type_ == MATCH_OUTPUT ?
133  aiter.Value().olabel : aiter.Value().ilabel);
134  if ((highest_label+1) * opts_.table_ratio > num_arcs) {
135  this_table_ = empty;
136  backoff_matcher_.SetState(s);
137  return; // table would be too sparse.
138  }
139  // OK, now we are creating the table.
140  this_table_ = new std::vector<ArcId> (highest_label+1, kNoStateId);
141  ArcId pos = 0;
142  for (aiter.Seek(0); !aiter.Done(); aiter.Next(), pos++) {
143  Label label = (match_type_ == MATCH_OUTPUT ?
144  aiter.Value().olabel : aiter.Value().ilabel);
145  assert((size_t)label <= (size_t)highest_label); // also checks >= 0.
146  if ((*this_table_)[label] == kNoStateId) (*this_table_)[label] = pos;
147  // set this_table_[label] to first position where arc has this
148  // label.
149  }
150  }
151  // At this point in the code, this_table_ != NULL and != empty.
152  aiter_ = new ArcIterator<FST>(*fst_, s);
153  aiter_->SetFlags(kArcNoCache, kArcNoCache); // don't need to cache arcs as may only
154  // need a small subset.
155  loop_.nextstate = s;
156  // aiter_ = NULL;
157  // backoff_matcher_.SetState(s);
158  }
159 
160  bool Find(Label match_label) {
161  if (!aiter_) return backoff_matcher_.Find(match_label);
162  else {
163  match_label_ = match_label;
164  current_loop_ = (match_label == 0);
165  // kNoLabel means the implicit loop on the other FST --
166  // matches real epsilons but not the self-loop.
167  match_label_ = (match_label_ == kNoLabel ? 0 : match_label_);
168  if (static_cast<size_t>(match_label_) < tables_[s_]->size() &&
169  (*(tables_[s_]))[match_label_] != kNoStateId) {
170  aiter_->Seek( (*(tables_[s_]))[match_label_] ); // label exists.
171  return true;
172  }
173  return current_loop_;
174  }
175  }
176  const Arc& Value() const {
177  if (aiter_)
178  return current_loop_ ? loop_ : aiter_->Value();
179  else
180  return backoff_matcher_.Value();
181  }
182 
183  void Next() {
184  if (aiter_) {
185  if (current_loop_)
186  current_loop_ = false;
187  else
188  aiter_->Next();
189  } else
190  backoff_matcher_.Next();
191  }
192 
193  bool Done() const {
194  if (aiter_ != NULL) {
195  if (current_loop_)
196  return false;
197  if (aiter_->Done())
198  return true;
199  Label label = (match_type_ == MATCH_OUTPUT ?
200  aiter_->Value().olabel : aiter_->Value().ilabel);
201  return (label != match_label_);
202  } else
203  return backoff_matcher_.Done();
204  }
205  const Arc &Value() {
206  if (aiter_ != NULL) {
207  return (current_loop_ ? loop_ : aiter_->Value() );
208  } else
209  return backoff_matcher_.Value();
210  }
211 
212  virtual TableMatcherImpl<FST> *Copy(bool safe = false) const {
213  assert(0); // shouldn't be called. This is not a "real" matcher,
214  // although we derive from MatcherBase for convenience.
215  return NULL;
216  }
217 
218  virtual uint64 Properties(uint64 props) const { return props; } // simple matcher that does
219  // not change its FST, so properties are properties of FST it is applied to
220 
221  private:
222  virtual void SetState_(StateId s) { SetState(s); }
223  virtual bool Find_(Label label) { return Find(label); }
224  virtual bool Done_() const { return Done(); }
225  virtual const Arc& Value_() const { return Value(); }
226  virtual void Next_() { Next(); }
227 
228  MatchType match_type_;
229  FST *fst_;
232  Arc loop_;
233  ArcIterator<FST> *aiter_;
234  StateId s_;
235  std::vector<std::vector<ArcId> *> tables_;
237  BackoffMatcher backoff_matcher_;
238 
239 };
240 
241 
242 template<class F, class BackoffMatcher = SortedMatcher<F> >
243 class TableMatcher : public MatcherBase<typename F::Arc> {
244  public:
245  typedef F FST;
246  typedef typename F::Arc Arc;
247  typedef typename Arc::Label Label;
248  typedef typename Arc::StateId StateId;
249  typedef StateId ArcId; // Use this type to store arc offsets [it's actually size_t
250  // in the Seek function of ArcIterator, but StateId should be big enough].
251  typedef typename Arc::Weight Weight;
253 
254  TableMatcher(const FST &fst, MatchType match_type,
256  : impl_(std::make_shared<Impl>(fst, match_type, opts)) { }
257 
259  bool safe = false)
260  : impl_(matcher.impl_) {
261  if (safe == true) {
262  LOG(FATAL) << "TableMatcher: Safe copy not supported";
263  }
264  }
265 
266  virtual const FST &GetFst() const { return impl_->GetFst(); }
267 
268  virtual MatchType Type(bool test) const { return impl_->Type(test); }
269 
270  void SetState(StateId s) { return impl_->SetState(s); }
271 
272  bool Find(Label match_label) { return impl_->Find(match_label); }
273 
274  const Arc& Value() const { return impl_->Value(); }
275 
276  void Next() { return impl_->Next(); }
277 
278  bool Done() const { return impl_->Done(); }
279 
280  const Arc &Value() { return impl_->Value(); }
281 
282  virtual TableMatcher<FST, BackoffMatcher> *Copy(bool safe = false) const {
283  return new TableMatcher<FST, BackoffMatcher> (*this, safe);
284  }
285 
286  virtual uint64 Properties(uint64 props) const { return impl_->Properties(props); } // simple matcher that does
287  // not change its FST, so properties are properties of FST it is applied to
288  private:
289  std::shared_ptr<Impl> impl_;
290 
291  virtual void SetState_(StateId s) { impl_->SetState(s); }
292  virtual bool Find_(Label label) { return impl_->Find(label); }
293  virtual bool Done_() const { return impl_->Done(); }
294  virtual const Arc& Value_() const { return impl_->Value(); }
295  virtual void Next_() { impl_->Next(); }
296 
297  TableMatcher &operator=(const TableMatcher &) = delete;
298 };
299 
301  bool connect; // Connect output
302  ComposeFilter filter_type; // Which pre-defined filter to use
303  MatchType table_match_type;
304 
306  bool c = true, ComposeFilter ft = SEQUENCE_FILTER,
307  MatchType tms = MATCH_OUTPUT)
308  : TableMatcherOptions(mo), connect(c), filter_type(ft), table_match_type(tms) { }
309  TableComposeOptions() : connect(true), filter_type(SEQUENCE_FILTER),
310  table_match_type(MATCH_OUTPUT) { }
311 };
312 
313 
314 template<class Arc>
315 void TableCompose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
316  MutableFst<Arc> *ofst,
317  const TableComposeOptions &opts = TableComposeOptions()) {
318  typedef Fst<Arc> F;
319  CacheOptions nopts;
320  nopts.gc_limit = 0; // Cache only the last state for fastest copy.
321  if (opts.table_match_type == MATCH_OUTPUT) {
322  // ComposeFstImplOptions templated on matcher for fst1, matcher for fst2.
323  ComposeFstImplOptions<TableMatcher<F>, SortedMatcher<F> > impl_opts(nopts);
324  impl_opts.matcher1 = new TableMatcher<F>(ifst1, MATCH_OUTPUT, opts);
325  *ofst = ComposeFst<Arc>(ifst1, ifst2, impl_opts);
326  } else {
327  assert(opts.table_match_type == MATCH_INPUT) ;
328  // ComposeFstImplOptions templated on matcher for fst1, matcher for fst2.
329  ComposeFstImplOptions<SortedMatcher<F>, TableMatcher<F> > impl_opts(nopts);
330  impl_opts.matcher2 = new TableMatcher<F>(ifst2, MATCH_INPUT, opts);
331  *ofst = ComposeFst<Arc>(ifst1, ifst2, impl_opts);
332  }
333  if (opts.connect) Connect(ofst);
334 }
335 
336 
339 template<class F>
343  TableComposeCache(const TableComposeOptions &opts = TableComposeOptions()): matcher (NULL), opts(opts) {}
344  ~TableComposeCache() { delete(matcher); }
345 };
346 
347 template<class Arc>
348 void TableCompose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
349  MutableFst<Arc> *ofst,
350  TableComposeCache<Fst<Arc> > *cache) {
351  typedef Fst<Arc> F;
352  assert(cache != NULL);
353  CacheOptions nopts;
354  nopts.gc_limit = 0; // Cache only the last state for fastest copy.
355  if (cache->opts.table_match_type == MATCH_OUTPUT) {
356  ComposeFstImplOptions<TableMatcher<F>, SortedMatcher<F> > impl_opts(nopts);
357  if (cache->matcher == NULL)
358  cache->matcher = new TableMatcher<F>(ifst1, MATCH_OUTPUT, cache->opts);
359  impl_opts.matcher1 = cache->matcher->Copy(); // not passing "safe": may not
360  // be thread-safe-- anway I don't understand this part.
361  *ofst = ComposeFst<Arc>(ifst1, ifst2, impl_opts);
362  } else {
363  assert(cache->opts.table_match_type == MATCH_INPUT) ;
364  ComposeFstImplOptions<SortedMatcher<F>, TableMatcher<F> > impl_opts(nopts);
365  if (cache->matcher == NULL)
366  cache->matcher = new TableMatcher<F>(ifst2, MATCH_INPUT, cache->opts);
367  impl_opts.matcher2 = cache->matcher->Copy();
368  *ofst = ComposeFst<Arc>(ifst1, ifst2, impl_opts);
369  }
370  if (cache->opts.connect) Connect(ofst);
371 }
372 
373 
374 
375 } // end namespace fst
376 #endif
fst::StdArc::StateId StateId
TableMatcher< F > * matcher
virtual bool Done_() const
virtual MatchType Type(bool test) const
Arc::StateId StateId
Definition: table-matcher.h:58
Arc::Weight Weight
TableComposeOptions opts
void TableCompose(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst, const TableComposeOptions &opts=TableComposeOptions())
const Arc & Value() const
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
TableMatcher(const TableMatcher< FST, BackoffMatcher > &matcher, bool safe=false)
virtual TableMatcher< FST, BackoffMatcher > * Copy(bool safe=false) const
virtual const Arc & Value_() const
void SetState(StateId s)
virtual const FST & GetFst() const
Definition: table-matcher.h:86
bool Done() const
TableMatcher(const FST &fst, MatchType match_type, const TableMatcherOptions &opts=TableMatcherOptions())
TableComposeOptions(const TableMatcherOptions &mo, bool c=true, ComposeFilter ft=SEQUENCE_FILTER, MatchType tms=MATCH_OUTPUT)
virtual void SetState_(StateId s)
virtual void SetState_(StateId s)
virtual MatchType Type(bool test) const
Definition: table-matcher.h:98
virtual void Next_()
virtual uint64 Properties(uint64 props) const
TableComposeCache lets us do multiple compositions while caching the same matcher.
std::shared_ptr< Impl > impl_
const Arc & Value()
virtual ~TableMatcherImpl()
Definition: table-matcher.h:88
virtual uint64 Properties(uint64 props) const
TableMatcher is a matcher specialized for the case where the output side of the left FST always has e...
Definition: table-matcher.h:42
TableMatcherOptions opts_
virtual bool Find_(Label label)
Arc::StateId StateId
TableComposeCache(const TableComposeOptions &opts=TableComposeOptions())
fst::StdArc::Label Label
fst::StdArc::Weight Weight
virtual const FST & GetFst() const
virtual TableMatcherImpl< FST > * Copy(bool safe=false) const
ArcIterator< FST > * aiter_
virtual void Next_()
virtual bool Done_() const
virtual const Arc & Value_() const
TableMatcherImpl(const FST &fst, MatchType match_type, const TableMatcherOptions &opts=TableMatcherOptions())
Definition: table-matcher.h:66
bool Find(Label match_label)
bool Find(Label match_label)
virtual bool Find_(Label label)
void SetState(StateId s)
BackoffMatcher backoff_matcher_
TableMatcherImpl< F, BackoffMatcher > Impl
const Arc & Value() const
std::vector< std::vector< ArcId > * > tables_
void Copy(const CuMatrixBase< Real > &src, const CuArray< int32 > &copy_from_indices, CuMatrixBase< Real > *tgt)
Copies elements from src into tgt as given by copy_from_indices.
Definition: cu-math.cc:173