factor-inl.h
Go to the documentation of this file.
1 // fstext/factor-inl.h
2 
3 // Copyright 2009-2011 Microsoft Corporation
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #ifndef KALDI_FSTEXT_FACTOR_INL_H_
21 #define KALDI_FSTEXT_FACTOR_INL_H_
22 
23 #include "util/stl-utils.h"
24 // Do not include this file directly. It is included by factor.h.
25 
26 namespace fst {
27 
28 // GetStateProperties takes in an FST and a number "max_state" which is the
29 // highest numbered state in the FST (this could be fst.NumStates()-1 for an
30 // ExpandedFst, or derived from some kind of traversal). It outputs a vector
31 // numbered from 0..max_state, of type FstStateProperties which is a bitmask
32 // with information about the states.
33 
34 // GetStateProperties has not been tested directly (only implicitly via
35 // testing Factor).
36 template<class Arc>
37 void GetStateProperties(const Fst<Arc> &fst,
38  typename Arc::StateId max_state,
39  std::vector<StatePropertiesType> *props) {
40  typedef typename Arc::StateId StateId;
41  typedef typename Arc::Weight Weight;
42  assert(props != NULL);
43  props->clear();
44  if (fst.Start() < 0) return; // Empty fst.
45  props->resize(max_state+1, 0);
46  assert(fst.Start() <= max_state);
47  (*props)[fst.Start()] |= kStateInitial;
48  for (StateId s = 0; s <= max_state; s++) {
49  StatePropertiesType &s_info = (*props)[s];
50  for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
51  const Arc &arc = aiter.Value();
52  if (arc.ilabel != 0) s_info |= kStateIlabelsOut;
53  if (arc.olabel != 0) s_info |= kStateOlabelsOut;
54  StateId nexts = arc.nextstate;
55  assert(nexts <= max_state); // or input was invalid.
56  StatePropertiesType &nexts_info = (*props)[nexts];
57  if (s_info&kStateArcsOut) s_info |= kStateMultipleArcsOut;
58  s_info |= kStateArcsOut;
59  if (nexts_info&kStateArcsIn) nexts_info |= kStateMultipleArcsIn;
60  nexts_info |= kStateArcsIn;
61  }
62  if (fst.Final(s) != Weight::Zero()) s_info |= kStateFinal;
63  }
64 }
65 
66 
67 
68 template<class Arc, class I>
69 void Factor(const Fst<Arc> &fst, MutableFst<Arc> *ofst,
70  std::vector<std::vector<I> > *symbols_out) {
72  typedef typename Arc::StateId StateId;
73  typedef typename Arc::Label Label;
74  typedef typename Arc::Weight Weight;
75  assert(symbols_out != NULL);
76  ofst->DeleteStates();
77  if (fst.Start() < 0) return; // empty FST.
78  std::vector<StateId> order;
79  DfsOrderVisitor<Arc> dfs_order_visitor(&order);
80  DfsVisit(fst, &dfs_order_visitor);
81  assert(order.size() > 0);
82  StateId max_state = *(std::max_element(order.begin(), order.end()));
83  std::vector<StatePropertiesType> state_properties;
84  GetStateProperties(fst, max_state, &state_properties);
85 
86  std::vector<bool> remove(max_state+1); // if true, will remove this state.
87 
88  // Now identify states that will be removed (made the middle of a chain).
89  // The basic rule is that if the FstStateProperties equals
90  // (kStateArcsIn|kStateArcsOut) or (kStateArcsIn|kStateArcsOut|kStateIlabelsOut),
91  // then it is in the middle of a chain. This eliminates state with
92  // multiple input or output arcs, final states, and states with arcs out
93  // that have olabels [we assume these are pushed to the left, so occur on the
94  // 1st arc of a chain.
95 
96  for (StateId i = 0; i <= max_state; i++)
97  remove[i] = (state_properties[i] == (kStateArcsIn|kStateArcsOut)
98  || state_properties[i] == (kStateArcsIn|kStateArcsOut|kStateIlabelsOut));
99  std::vector<StateId> state_mapping(max_state+1, kNoStateId);
100 
101  typedef unordered_map<std::vector<I>, Label, kaldi::VectorHasher<I> > SymbolMapType;
102  SymbolMapType symbol_mapping;
103  Label symbol_counter = 0;
104  {
105  std::vector<I> eps;
106  symbol_mapping[eps] = symbol_counter++;
107  }
108  std::vector<I> this_sym; // a temporary used inside the loop.
109  for (size_t i = 0; i < order.size(); i++) {
110  StateId state = order[i];
111  if (!remove[state]) { // Process this state...
112  StateId &new_state = state_mapping[state];
113  if (new_state == kNoStateId) new_state = ofst->AddState();
114  for (ArcIterator<Fst<Arc> > aiter(fst, state); !aiter.Done(); aiter.Next()) {
115  Arc arc = aiter.Value();
116  if (arc.ilabel == 0) this_sym.clear();
117  else {
118  this_sym.resize(1);
119  this_sym[0] = arc.ilabel;
120  }
121  while (remove[arc.nextstate]) {
122  ArcIterator<Fst<Arc> > aiter2(fst, arc.nextstate);
123  assert(!aiter2.Done());
124  const Arc &nextarc = aiter2.Value();
125  arc.weight = Times(arc.weight, nextarc.weight);
126  assert(nextarc.olabel == 0);
127  if (nextarc.ilabel != 0) this_sym.push_back(nextarc.ilabel);
128  assert(static_cast<Label>(static_cast<I>(nextarc.ilabel))
129  == nextarc.ilabel); // check within integer range.
130  arc.nextstate = nextarc.nextstate;
131  }
132  StateId &new_nextstate = state_mapping[arc.nextstate];
133  if (new_nextstate == kNoStateId) new_nextstate = ofst->AddState();
134  arc.nextstate = new_nextstate;
135  if (symbol_mapping.count(this_sym) != 0) arc.ilabel = symbol_mapping[this_sym];
136  else arc.ilabel = symbol_mapping[this_sym] = symbol_counter++;
137  ofst->AddArc(new_state, arc);
138  }
139  if (fst.Final(state) != Weight::Zero())
140  ofst->SetFinal(new_state, fst.Final(state));
141  }
142  }
143  ofst->SetStart(state_mapping[fst.Start()]);
144 
145  // Now output the symbol sequences.
146  symbols_out->resize(symbol_counter);
147  for (typename SymbolMapType::const_iterator iter = symbol_mapping.begin();
148  iter != symbol_mapping.end(); ++iter) {
149  (*symbols_out)[iter->second] = iter->first;
150  }
151 }
152 
153 template<class Arc>
154 void Factor(const Fst<Arc> &fst, MutableFst<Arc> *ofst1,
155  MutableFst<Arc> *ofst2) {
156  typedef typename Arc::Label Label;
157  std::vector<std::vector<Label> > symbols;
158  Factor(fst, ofst2, &symbols);
159  CreateFactorFst(symbols, ofst1);
160 }
161 
162 template<class Arc, class I>
163 void ExpandInputSequences(const std::vector<std::vector<I> > &sequences,
164  MutableFst<Arc> *fst) {
166  typedef typename Arc::StateId StateId;
167  typedef typename Arc::Label Label;
168  typedef typename Arc::Weight Weight;
169  fst->SetInputSymbols(NULL);
170  size_t size = sequences.size();
171  if (sequences.size() > 0) assert(sequences[0].size() == 0); // should be eps.
172  StateId num_states_at_start = fst->NumStates();
173  for (StateId s = 0; s < num_states_at_start; s++) {
174  StateId num_arcs = fst->NumArcs(s);
175  for (StateId aidx = 0; aidx < num_arcs; aidx++) {
176  ArcIterator<MutableFst<Arc> > aiter(*fst, s);
177  aiter.Seek(aidx);
178  Arc arc = aiter.Value();
179 
180  Label ilabel = arc.ilabel;
181  Label dest_state = arc.nextstate;
182  if (ilabel != 0) { // non-eps [nothing to do if eps]...
183  assert(ilabel < static_cast<Label>(size));
184  size_t len = sequences[ilabel].size();
185  if (len <= 1) {
186  if (len == 0) arc.ilabel = 0;
187  else arc.ilabel = sequences[ilabel][0];
188  MutableArcIterator<MutableFst<Arc> > mut_aiter(fst, s);
189  mut_aiter.Seek(aidx);
190  mut_aiter.SetValue(arc);
191  } else { // len>=2. Must create new states...
192  StateId curstate = -1; // keep compiler happy: this value never used.
193  for (size_t n = 0; n < len; n++) { // adding/modifying "len" arcs.
194  StateId nextstate;
195  if (n < len-1) {
196  nextstate = fst->AddState();
197  assert(nextstate >= num_states_at_start);
198  } else nextstate = dest_state; // going back to original arc's
199  // destination.
200  if (n == 0) {
201  arc.ilabel = sequences[ilabel][0];
202  arc.nextstate = nextstate;
203  MutableArcIterator<MutableFst<Arc> > mut_aiter(fst, s);
204  mut_aiter.Seek(aidx);
205  mut_aiter.SetValue(arc);
206  } else {
207  arc.ilabel = sequences[ilabel][n];
208  arc.olabel = 0;
209  arc.weight = Weight::One();
210  arc.nextstate = nextstate;
211  fst->AddArc(curstate, arc);
212  }
213  curstate = nextstate;
214  }
215  }
216  }
217  }
218  }
219 }
220 
221 
222 template<class Arc, class I>
224 public:
225  Arc operator ()(const Arc &arc_in) {
226  Arc ans = arc_in;
227  if (to_remove_set_.count(ans.ilabel) != 0) ans.ilabel = 0; // remove this symbol
228  return ans;
229  }
230  MapFinalAction FinalAction() { return MAP_NO_SUPERFINAL; }
231  MapSymbolsAction InputSymbolsAction() { return MAP_CLEAR_SYMBOLS; }
232  MapSymbolsAction OutputSymbolsAction() { return MAP_COPY_SYMBOLS; }
233  uint64 Properties(uint64 props) const {
234  // remove the following as we don't know now if any of them are true.
235  uint64 to_remove = kAcceptor|kNotAcceptor|kIDeterministic|kNonIDeterministic|
236  kNoEpsilons|kNoIEpsilons|kILabelSorted|kNotILabelSorted;
237  return props & ~to_remove;
238  }
239  RemoveSomeInputSymbolsMapper(const std::vector<I> &to_remove):
240  to_remove_set_(to_remove) {
242  assert(to_remove_set_.count(0) == 0); // makes no sense to remove epsilon.
243  }
244 private:
246 };
247 
248 
249 template<class Arc, class I>
250 void CreateFactorFst(const std::vector<std::vector<I> > &sequences,
251  MutableFst<Arc> *fst) {
253  typedef typename Arc::StateId StateId;
254  typedef typename Arc::Label Label;
255  typedef typename Arc::Weight Weight;
256 
257  assert(fst != NULL);
258  fst->DeleteStates();
259  StateId loopstate = fst->AddState();
260  assert(loopstate == 0);
261  fst->SetStart(0);
262  fst->SetFinal(0, Weight::One());
263  if (sequences.size() != 0) assert(sequences[0].size() == 0); // can't replace epsilon...
264 
265  for (Label olabel = 1; olabel < static_cast<Label>(sequences.size()); olabel++) {
266  size_t len = sequences[olabel].size();
267  if (len == 0) {
268  Arc arc(0, olabel, Weight::One(), loopstate);
269  fst->AddArc(loopstate, arc);
270  } else {
271  StateId curstate = loopstate;
272  for (size_t i = 0; i < len; i++) {
273  StateId nextstate = (i == len-1 ? loopstate : fst->AddState());
274  Arc arc(sequences[olabel][i], (i == 0 ? olabel : 0), Weight::One(), nextstate);
275  fst->AddArc(curstate, arc);
276  curstate = nextstate;
277  }
278  }
279  }
280  fst->SetProperties(kOLabelSorted, kOLabelSorted);
281 }
282 
283 
284 template<class Arc, class I>
285 void CreateMapFst(const std::vector<I> &symbol_map,
286  MutableFst<Arc> *fst) {
288  typedef typename Arc::StateId StateId;
289  typedef typename Arc::Label Label;
290  typedef typename Arc::Weight Weight;
291 
292  assert(fst != NULL);
293  fst->DeleteStates();
294  StateId loopstate = fst->AddState();
295  assert(loopstate == 0);
296  fst->SetStart(0);
297  fst->SetFinal(0, Weight::One());
298  assert(symbol_map.empty() || symbol_map[0] == 0); // FST cannot map epsilon to something else.
299  for (Label olabel = 1; olabel < static_cast<Label>(symbol_map.size()); olabel++) {
300  Arc arc(symbol_map[olabel], olabel, Weight::One(), loopstate);
301  fst->AddArc(loopstate, arc);
302  }
303 }
304 
305 
306 
307 
308 } // end namespace fst.
309 
310 #endif
fst::StdArc::StateId StateId
uint64 Properties(uint64 props) const
Definition: factor-inl.h:233
A hashing function-object for vectors.
Definition: stl-utils.h:216
#define KALDI_ASSERT_IS_INTEGER_TYPE(I)
Definition: kaldi-utils.h:133
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
Arc operator()(const Arc &arc_in)
Definition: factor-inl.h:225
void GetStateProperties(const Fst< Arc > &fst, typename Arc::StateId max_state, std::vector< StatePropertiesType > *props)
This function works out various properties of the states in the FST, using the bit properties defined...
Definition: factor-inl.h:37
LatticeWeightTpl< FloatType > Times(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
MapSymbolsAction InputSymbolsAction()
Definition: factor-inl.h:231
unsigned char StatePropertiesType
Definition: factor.h:122
struct rnnlm::@11::@12 n
void CreateFactorFst(const std::vector< std::vector< I > > &sequences, MutableFst< Arc > *fst)
The function CreateFactorFst will create an FST that expands out the "factors" that are the indices o...
Definition: factor-inl.h:250
void Factor(const Fst< Arc > &fst, MutableFst< Arc > *ofst, std::vector< std::vector< I > > *symbols_out)
Factor identifies linear chains of states with an olabel (if any) only on the first arc of the chain...
Definition: factor-inl.h:69
RemoveSomeInputSymbolsMapper(const std::vector< I > &to_remove)
Definition: factor-inl.h:239
fst::StdArc::Label Label
fst::StdArc::Weight Weight
MapSymbolsAction OutputSymbolsAction()
Definition: factor-inl.h:232
kaldi::ConstIntegerSet< I > to_remove_set_
Definition: factor-inl.h:245
void CreateMapFst(const std::vector< I > &symbol_map, MutableFst< Arc > *fst)
CreateMapFst will create an FST representing this symbol_map.
Definition: factor-inl.h:285
void ExpandInputSequences(const std::vector< std::vector< I > > &sequences, MutableFst< Arc > *fst)
ExpandInputSequences expands out the input symbols into sequences of input symbols.
Definition: factor-inl.h:163