20 #ifndef KALDI_FSTEXT_TABLE_MATCHER_H_ 21 #define KALDI_FSTEXT_TABLE_MATCHER_H_ 22 #include <fst/fstlib.h> 23 #include <fst/fst-decl.h> 52 template<
class F,
class BackoffMatcher = SortedMatcher<F> >
68 match_type_(match_type),
70 loop_(match_type == MATCH_INPUT ?
71 Arc(kNoLabel, 0, Weight::One(), kNoStateId) :
72 Arc(0, kNoLabel, Weight::One(), kNoStateId)),
74 s_(kNoStateId), opts_(opts),
75 backoff_matcher_(fst, match_type)
77 assert(opts_.min_table_size > 0);
78 if (match_type == MATCH_INPUT)
79 assert(fst_->Properties(kILabelSorted,
true) == kILabelSorted);
80 else if (match_type == MATCH_OUTPUT)
81 assert(fst_->Properties(kOLabelSorted,
true) == kOLabelSorted);
83 assert(0 &&
"Invalid FST properties");
86 virtual const FST &
GetFst()
const {
return *fst_; }
89 std::vector<ArcId> *
const empty = ((std::vector<ArcId>*)(NULL)) + 1;
90 for (
size_t i = 0;
i < tables_.size();
i++) {
91 if (tables_[
i] != NULL && tables_[
i] != empty)
98 virtual MatchType
Type(
bool test)
const {
107 if (match_type_ == MATCH_NONE)
108 LOG(FATAL) <<
"TableMatcher: bad match type";
110 std::vector<ArcId> *
const empty = ((std::vector<ArcId>*)(NULL)) + 1;
111 if (static_cast<size_t>(s) >= tables_.size()) {
113 tables_.resize(s+1, NULL);
115 std::vector<ArcId>* &this_table_ = tables_[s];
116 if (this_table_ == empty) {
117 backoff_matcher_.SetState(s);
119 }
else if (this_table_ == NULL) {
120 ArcId num_arcs = fst_->NumArcs(s);
121 if (num_arcs == 0 || num_arcs < opts_.min_table_size) {
123 backoff_matcher_.SetState(s);
126 ArcIterator<FST> aiter(*fst_, s);
127 aiter.SetFlags(kArcNoCache|(match_type_ == MATCH_OUTPUT?kArcOLabelValue:kArcILabelValue),
128 kArcNoCache|kArcValueFlags);
131 aiter.Seek(num_arcs - 1);
132 Label highest_label = (match_type_ == MATCH_OUTPUT ?
133 aiter.Value().olabel : aiter.Value().ilabel);
134 if ((highest_label+1) * opts_.table_ratio > num_arcs) {
136 backoff_matcher_.SetState(s);
140 this_table_ =
new std::vector<ArcId> (highest_label+1, kNoStateId);
142 for (aiter.Seek(0); !aiter.Done(); aiter.Next(), pos++) {
143 Label label = (match_type_ == MATCH_OUTPUT ?
144 aiter.Value().olabel : aiter.Value().ilabel);
145 assert((
size_t)label <= (
size_t)highest_label);
146 if ((*this_table_)[label] == kNoStateId) (*this_table_)[label] = pos;
152 aiter_ =
new ArcIterator<FST>(*fst_, s);
153 aiter_->SetFlags(kArcNoCache, kArcNoCache);
161 if (!aiter_)
return backoff_matcher_.Find(match_label);
163 match_label_ = match_label;
164 current_loop_ = (match_label == 0);
167 match_label_ = (match_label_ == kNoLabel ? 0 : match_label_);
168 if (static_cast<size_t>(match_label_) < tables_[s_]->size() &&
169 (*(tables_[s_]))[match_label_] != kNoStateId) {
170 aiter_->Seek( (*(tables_[s_]))[match_label_] );
173 return current_loop_;
178 return current_loop_ ? loop_ : aiter_->Value();
180 return backoff_matcher_.Value();
186 current_loop_ =
false;
190 backoff_matcher_.Next();
194 if (aiter_ != NULL) {
199 Label label = (match_type_ == MATCH_OUTPUT ?
200 aiter_->Value().olabel : aiter_->Value().ilabel);
201 return (label != match_label_);
203 return backoff_matcher_.Done();
206 if (aiter_ != NULL) {
207 return (current_loop_ ? loop_ : aiter_->Value() );
209 return backoff_matcher_.Value();
218 virtual uint64
Properties(uint64 props)
const {
return props; }
223 virtual bool Find_(Label label) {
return Find(label); }
224 virtual bool Done_()
const {
return Done(); }
225 virtual const Arc&
Value_()
const {
return Value(); }
242 template<
class F,
class BackoffMatcher = SortedMatcher<F> >
256 : impl_(std::make_shared<Impl>(fst, match_type, opts)) { }
260 : impl_(matcher.impl_) {
262 LOG(FATAL) <<
"TableMatcher: Safe copy not supported";
266 virtual const FST &
GetFst()
const {
return impl_->GetFst(); }
268 virtual MatchType
Type(
bool test)
const {
return impl_->Type(test); }
270 void SetState(StateId s) {
return impl_->SetState(s); }
272 bool Find(Label match_label) {
return impl_->Find(match_label); }
274 const Arc&
Value()
const {
return impl_->Value(); }
276 void Next() {
return impl_->Next(); }
278 bool Done()
const {
return impl_->Done(); }
280 const Arc &
Value() {
return impl_->Value(); }
286 virtual uint64
Properties(uint64 props)
const {
return impl_->Properties(props); }
291 virtual void SetState_(StateId s) { impl_->SetState(s); }
292 virtual bool Find_(Label label) {
return impl_->Find(label); }
293 virtual bool Done_()
const {
return impl_->Done(); }
294 virtual const Arc&
Value_()
const {
return impl_->Value(); }
295 virtual void Next_() { impl_->Next(); }
306 bool c =
true, ComposeFilter ft = SEQUENCE_FILTER,
307 MatchType tms = MATCH_OUTPUT)
310 table_match_type(MATCH_OUTPUT) { }
316 MutableFst<Arc> *ofst,
321 if (opts.table_match_type == MATCH_OUTPUT) {
323 ComposeFstImplOptions<TableMatcher<F>, SortedMatcher<F> > impl_opts(nopts);
325 *ofst = ComposeFst<Arc>(ifst1, ifst2, impl_opts);
327 assert(opts.table_match_type == MATCH_INPUT) ;
329 ComposeFstImplOptions<SortedMatcher<F>,
TableMatcher<F> > impl_opts(nopts);
331 *ofst = ComposeFst<Arc>(ifst1, ifst2, impl_opts);
333 if (opts.connect) Connect(ofst);
349 MutableFst<Arc> *ofst,
352 assert(cache != NULL);
355 if (cache->opts.table_match_type == MATCH_OUTPUT) {
356 ComposeFstImplOptions<TableMatcher<F>, SortedMatcher<F> > impl_opts(nopts);
357 if (cache->matcher == NULL)
358 cache->matcher =
new TableMatcher<F>(ifst1, MATCH_OUTPUT, cache->opts);
359 impl_opts.matcher1 = cache->matcher->
Copy();
361 *ofst = ComposeFst<Arc>(ifst1, ifst2, impl_opts);
363 assert(cache->opts.table_match_type == MATCH_INPUT) ;
364 ComposeFstImplOptions<SortedMatcher<F>,
TableMatcher<F> > impl_opts(nopts);
365 if (cache->matcher == NULL)
367 impl_opts.matcher2 = cache->matcher->
Copy();
368 *ofst = ComposeFst<Arc>(ifst1, ifst2, impl_opts);
370 if (cache->opts.connect) Connect(ofst);
fst::StdArc::StateId StateId
TableMatcher< F > * matcher
virtual bool Done_() const
virtual MatchType Type(bool test) const
void TableCompose(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst, const TableComposeOptions &opts=TableComposeOptions())
const Arc & Value() const
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
TableMatcher(const TableMatcher< FST, BackoffMatcher > &matcher, bool safe=false)
virtual TableMatcher< FST, BackoffMatcher > * Copy(bool safe=false) const
virtual const Arc & Value_() const
virtual const FST & GetFst() const
TableMatcher(const FST &fst, MatchType match_type, const TableMatcherOptions &opts=TableMatcherOptions())
TableComposeOptions(const TableMatcherOptions &mo, bool c=true, ComposeFilter ft=SEQUENCE_FILTER, MatchType tms=MATCH_OUTPUT)
virtual void SetState_(StateId s)
virtual void SetState_(StateId s)
virtual MatchType Type(bool test) const
virtual uint64 Properties(uint64 props) const
TableComposeCache lets us do multiple compositions while caching the same matcher.
std::shared_ptr< Impl > impl_
virtual ~TableMatcherImpl()
virtual uint64 Properties(uint64 props) const
ComposeFilter filter_type
TableMatcher is a matcher specialized for the case where the output side of the left FST always has e...
TableMatcherOptions opts_
virtual bool Find_(Label label)
TableComposeCache(const TableComposeOptions &opts=TableComposeOptions())
fst::StdArc::Weight Weight
virtual const FST & GetFst() const
virtual TableMatcherImpl< FST > * Copy(bool safe=false) const
ArcIterator< FST > * aiter_
MatchType table_match_type
virtual bool Done_() const
virtual const Arc & Value_() const
TableMatcherImpl(const FST &fst, MatchType match_type, const TableMatcherOptions &opts=TableMatcherOptions())
bool Find(Label match_label)
bool Find(Label match_label)
virtual bool Find_(Label label)
BackoffMatcher backoff_matcher_
TableMatcherImpl< F, BackoffMatcher > Impl
const Arc & Value() const
std::vector< std::vector< ArcId > * > tables_
void Copy(const CuMatrixBase< Real > &src, const CuArray< int32 > ©_from_indices, CuMatrixBase< Real > *tgt)
Copies elements from src into tgt as given by copy_from_indices.