43 #ifndef KALDI_FSTEXT_TRIVIAL_FACTOR_WEIGHT_H_ 44 #define KALDI_FSTEXT_TRIVIAL_FACTOR_WEIGHT_H_ 55 #include <unordered_map> 56 using std::unordered_map;
63 #include <fst/cache.h> 64 #include <fst/test-properties.h> 77 Label il = 0, Label ol = 0)
78 :
CacheOptions(opts), delta(d), extra_ilabel(il), extra_olabel(ol) {}
81 float d, Label il = 0, Label ol = 0)
82 : delta(d), extra_ilabel(il), extra_olabel(ol) {}
91 template <
class A,
class F>
96 using FstImpl<A>::SetType;
97 using FstImpl<A>::SetProperties;
98 using FstImpl<A>::Properties;
99 using FstImpl<A>::SetInputSymbols;
100 using FstImpl<A>::SetOutputSymbols;
102 using CacheBaseImpl< CacheState<A> >::HasStart;
103 using CacheBaseImpl< CacheState<A> >::HasFinal;
104 using CacheBaseImpl< CacheState<A> >::HasArcs;
113 typedef typename Store::State
State;
118 Element(StateId s, Weight w) : state(s), weight(w) {}
130 SetType(
"factor-weight");
131 uint64 props = fst.Properties(kFstProperties,
false);
132 SetProperties(FactorWeightProperties(props), kCopyProperties);
134 SetInputSymbols(fst.InputSymbols());
135 SetOutputSymbols(fst.OutputSymbols());
140 fst_(impl.fst_->
Copy(true)),
142 extra_ilabel_(impl.extra_ilabel_),
143 extra_olabel_(impl.extra_olabel_) {
144 SetType(
"factor-weight");
145 SetProperties(impl.Properties(), kCopyProperties);
146 SetInputSymbols(impl.InputSymbols());
147 SetOutputSymbols(impl.OutputSymbols());
152 StateId s = fst_->Start();
155 StateId start = this->FindState(
Element(fst_->Start(), Weight::One()));
156 this->SetStart(start);
163 const Element &e = elements_[s];
165 if (e.
state == kNoStateId) {
166 FactorIterator fit(e.
weight);
173 if (e.
weight != Weight::One()) {
176 w = fst_->Final(e.
state);
177 FactorIterator fit(w);
182 this->SetFinal(s, w);
217 typename ElementMap::iterator eit = element_map_.find(e);
218 if (eit != element_map_.end()) {
219 return (*eit).second;
221 StateId s = elements_.size();
222 elements_.push_back(e);
223 element_map_.insert(std::pair<const Element, StateId>(e, s));
231 CHECK(static_cast<size_t>(s) < elements_.size());
233 if (e.weight != Weight::One()) {
234 FactorIterator fit(e.weight);
236 if (e.state != kNoStateId) {
237 StateId dest = FindState(
Element(e.state, Weight::One()));
238 PushArc(s,
Arc(extra_ilabel_, extra_olabel_, e.weight, dest));
241 const std::pair<Weight, Weight> &p = fit.Value();
242 StateId dest = FindState(
Element(e.state, p.second.Quantize(delta_)));
243 PushArc(s,
Arc(extra_ilabel_, extra_olabel_, p.first, dest));
246 CHECK(e.state != kNoStateId);
247 for (ArcIterator< Fst<A> > ait(*fst_, e.state);
250 const A &arc = ait.Value();
251 FactorIterator fit(arc.weight);
253 StateId dest = FindState(
Element(arc.nextstate, Weight::One()));
254 PushArc(s,
Arc(arc.ilabel, arc.olabel, arc.weight, dest));
256 const std::pair<Weight, Weight> &p = fit.Value();
257 StateId dest = FindState(
Element(arc.nextstate, p.second.Quantize(delta_)));
258 PushArc(s,
Arc(arc.ilabel, arc.olabel, p.first, dest));
262 Weight final_w = fst_->Final(e.state);
263 if (final_w != Weight::Zero()) {
264 FactorIterator fit(final_w);
266 const std::pair<Weight, Weight> &p = fit.Value();
267 StateId dest = FindState(
Element(kNoStateId, p.second.Quantize(delta_)));
268 PushArc(s,
Arc(extra_ilabel_, extra_olabel_, p.first, dest));
288 return static_cast<size_t>(x.
state * kPrime + x.
weight.Hash());
291 static const int kPrime = 7853;
294 typedef unordered_map<Element, StateId, ElementKey, ElementEqual>
ElementMap;
296 std::unique_ptr<const Fst<A>>
fst_;
324 template <
class A,
class F>
326 public ImplToFst<internal::TrivialFactorWeightFstImpl<A, F>> {
334 typedef DefaultCacheStore<Arc>
Store;
335 typedef typename Store::State
State;
342 :
ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}
353 inline void InitStateIterator(StateIteratorData<A> *data)
const override;
356 GetMutableImpl()->InitArcIterator(s, data);
368 template<
class A,
class F>
378 template <
class A,
class F>
386 if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
390 template <
class A,
class F>
392 StateIteratorData<A> *data)
const {
393 data->base =
new StateIterator< TrivialFactorWeightFst<A, F> >(*this);
fst::StdArc::StateId StateId
size_t NumInputEpsilons(StateId s)
StateIterator(const TrivialFactorWeightFst< A, F > &fst)
TrivialFactorWeightOptions(float d, Label il=0, Label ol=0)
TrivialFactorWeightOptions(const CacheOptions &opts, float d, Label il=0, Label ol=0)
TrivialFactorWeightFstImpl(const TrivialFactorWeightFstImpl< A, F > &impl)
TrivialFactorWeightFst takes as template parameter a FactorIterator as defined above.
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
std::vector< Element > elements_
void InitArcIterator(StateId s, ArcIteratorData< A > *data) const override
TrivialFactorWeightFstImpl(const Fst< A > &fst, const TrivialFactorWeightOptions< A > &opts)
bool operator()(const Element &x, const Element &y) const
size_t NumArcs(StateId s)
void InitStateIterator(StateIteratorData< A > *data) const override
StateId FindState(const Element &e)
size_t operator()(const Element &x) const
TrivialFactorWeightFst(const Fst< A > &fst)
void InitArcIterator(StateId s, ArcIteratorData< A > *data)
DefaultCacheStore< A > Store
TrivialFactorWeightFst(const Fst< A > &fst, const TrivialFactorWeightOptions< A > &opts)
unordered_map< Element, StateId, ElementKey, ElementEqual > ElementMap
std::unique_ptr< const Fst< A > > fst_
ArcIterator(const TrivialFactorWeightFst< A, F > &fst, StateId s)
TrivialFactorWeightFst(const TrivialFactorWeightFst< A, F > &fst, bool copy)
size_t NumOutputEpsilons(StateId s)
DefaultCacheStore< Arc > Store
fst::StdArc::Weight Weight
internal::TrivialFactorWeightFstImpl< A, F > Impl
TrivialFactorWeightOptions()
TrivialFactorWeightFst< A, F > * Copy(bool copy=false) const override
Arc::StateId NumArcs(const ExpandedFst< Arc > &fst)
Returns the total number of arcs in an FST.
Element(StateId s, Weight w)
void Copy(const CuMatrixBase< Real > &src, const CuArray< int32 > ©_from_indices, CuMatrixBase< Real > *tgt)
Copies elements from src into tgt as given by copy_from_indices.