fstrmsymbols.cc
Go to the documentation of this file.
1 // fstbin/fstrmsymbols.cc
2 
3 // Copyright 2009-2011 Microsoft Corporation
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
21 #include "base/kaldi-common.h"
22 #include "util/common-utils.h"
23 #include "fst/fstlib.h"
25 #include "fstext/fstext-utils.h"
26 #include "fstext/kaldi-fst-io.h"
27 
28 namespace fst {
29 // we can move these functions elsewhere later, if they are needed in other
30 // places.
31 
32 template<class Arc, class I>
33 void RemoveArcsWithSomeInputSymbols(const std::vector<I> &symbols_in,
34  VectorFst<Arc> *fst) {
35  typedef typename Arc::StateId StateId;
36 
37  kaldi::ConstIntegerSet<I> symbol_set(symbols_in);
38 
39  StateId num_states = fst->NumStates();
40  StateId dead_state = fst->AddState();
41  for (StateId s = 0; s < num_states; s++) {
42  for (MutableArcIterator<VectorFst<Arc> > iter(fst, s);
43  !iter.Done(); iter.Next()) {
44  if (symbol_set.count(iter.Value().ilabel) != 0) {
45  Arc arc = iter.Value();
46  arc.nextstate = dead_state;
47  iter.SetValue(arc);
48  }
49  }
50  }
51  // Connect() will actually remove the arcs, and the dead state.
52  Connect(fst);
53  if (fst->NumStates() == 0)
54  KALDI_WARN << "After Connect(), fst was empty.";
55 }
56 
57 template<class Arc, class I>
58 void PenalizeArcsWithSomeInputSymbols(const std::vector<I> &symbols_in,
59  float penalty,
60  VectorFst<Arc> *fst) {
61  typedef typename Arc::StateId StateId;
62  typedef typename Arc::Label Label;
63  typedef typename Arc::Weight Weight;
64 
65  Weight penalty_weight(penalty);
66 
67  kaldi::ConstIntegerSet<I> symbol_set(symbols_in);
68 
69  StateId num_states = fst->NumStates();
70  for (StateId s = 0; s < num_states; s++) {
71  for (MutableArcIterator<VectorFst<Arc> > iter(fst, s);
72  !iter.Done(); iter.Next()) {
73  if (symbol_set.count(iter.Value().ilabel) != 0) {
74  Arc arc = iter.Value();
75  arc.weight = Times(arc.weight, penalty_weight);
76  iter.SetValue(arc);
77  }
78  }
79  }
80 }
81 
82 }
83 
84 
85 int main(int argc, char *argv[]) {
86  try {
87  using namespace kaldi;
88  using namespace fst;
89  using kaldi::int32;
90 
91  bool apply_to_output = false;
92  bool remove_arcs = false;
93  float penalty = -std::numeric_limits<BaseFloat>::infinity();
94 
95  const char *usage =
96  "With no options, replaces a subset of symbols with epsilon, wherever\n"
97  "they appear on the input side of an FST."
98  "With --remove-arcs=true, will remove arcs that contain these symbols\n"
99  "on the input\n"
100  "With --penalty=<float>, will add the specified penalty to the\n"
101  "cost of any arc that has one of the given symbols on its input side\n"
102  "In all cases, the option --apply-to-output=true (or for\n"
103  "back-compatibility, --remove-from-output=true) makes this apply\n"
104  "to the output side.\n"
105  "\n"
106  "Usage: fstrmsymbols [options] <in-disambig-list> [<in.fst> [<out.fst>]]\n"
107  "E.g: fstrmsymbols in.list < in.fst > out.fst\n"
108  "<in-disambig-list> is an rxfilename specifying a file containing list of integers\n"
109  "representing symbols, in text form, one per line.\n";
110 
111  ParseOptions po(usage);
112  po.Register("remove-from-output", &apply_to_output, "If true, this applies to symbols "
113  "on the output, not the input, side. (For back compatibility; use "
114  "--apply-to-output insead)");
115  po.Register("apply-to-output", &apply_to_output, "If true, this applies to symbols "
116  "on the output, not the input, side.");
117  po.Register("remove-arcs", &remove_arcs, "If true, instead of converting the symbol "
118  "to <eps>, remove the arcs.");
119  po.Register("penalty", &penalty, "If specified, instead of converting "
120  "the symbol to <eps>, penalize the arc it is on by adding this "
121  "value to its cost.");
122 
123 
124  po.Read(argc, argv);
125 
126  if (remove_arcs &&
127  penalty != -std::numeric_limits<BaseFloat>::infinity())
128  KALDI_ERR << "--remove-arc and --penalty options are mutually exclusive";
129 
130  if (po.NumArgs() < 1 || po.NumArgs() > 3) {
131  po.PrintUsage();
132  exit(1);
133  }
134 
135  std::string disambig_rxfilename = po.GetArg(1),
136  fst_rxfilename = po.GetOptArg(2),
137  fst_wxfilename = po.GetOptArg(3);
138 
139  VectorFst<StdArc> *fst = CastOrConvertToVectorFst(
140  ReadFstKaldiGeneric(fst_rxfilename));
141 
142  std::vector<int32> disambig_in;
143  if (!ReadIntegerVectorSimple(disambig_rxfilename, &disambig_in))
144  KALDI_ERR << "fstrmsymbols: Could not read disambiguation symbols from "
145  << (disambig_rxfilename == "" ? "standard input" : disambig_rxfilename);
146 
147  if (apply_to_output) Invert(fst);
148  if (remove_arcs) {
149  RemoveArcsWithSomeInputSymbols(disambig_in, fst);
150  } else if (penalty != -std::numeric_limits<BaseFloat>::infinity()) {
151  PenalizeArcsWithSomeInputSymbols(disambig_in, penalty, fst);
152  } else {
153  RemoveSomeInputSymbols(disambig_in, fst);
154  }
155  if (apply_to_output) Invert(fst);
156 
157  WriteFstKaldi(*fst, fst_wxfilename);
158 
159  delete fst;
160  return 0;
161  } catch(const std::exception &e) {
162  std::cerr << e.what();
163  return -1;
164  }
165 }
166 
167 /* some test examples:
168 
169  ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols "echo 3; echo 4|" | fstprint
170  # should produce:
171  # 0 0 1 1
172  # 0 0 0 2
173  # 0
174 
175  ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols --apply-to-output=true "echo 2; echo 3|" | fstprint
176  # should produce:
177  # 0 0 1 1
178  # 0 0 3 0
179  # 0
180 
181 
182  ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols --remove-arcs=true "echo 3; echo 4|" | fstprint
183  # should produce:
184  # 0 0 1 1
185  # 0
186 
187  ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols --penalty=2 "echo 3; echo 4; echo 5|" | fstprint
188 # should produce:
189  # 0 0 1 1
190  # 0 0 3 2 2
191  # 0
192 
193 */
fst::StdArc::StateId StateId
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
Fst< StdArc > * ReadFstKaldiGeneric(std::string rxfilename, bool throw_on_err)
Definition: kaldi-fst-io.cc:45
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
int main(int argc, char *argv[])
Definition: fstrmsymbols.cc:85
kaldi::int32 int32
void Register(const std::string &name, bool *ptr, const std::string &doc)
void PenalizeArcsWithSomeInputSymbols(const std::vector< I > &symbols_in, float penalty, VectorFst< Arc > *fst)
Definition: fstrmsymbols.cc:58
LatticeWeightTpl< FloatType > Times(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void RemoveArcsWithSomeInputSymbols(const std::vector< I > &symbols_in, VectorFst< Arc > *fst)
Definition: fstrmsymbols.cc:33
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
#define KALDI_WARN
Definition: kaldi-error.h:150
fst::StdArc::Label Label
fst::StdArc::Weight Weight
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void WriteFstKaldi(std::ostream &os, bool binary, const VectorFst< Arc > &t)
VectorFst< StdArc > * CastOrConvertToVectorFst(Fst< StdArc > *fst)
Definition: kaldi-fst-io.cc:94
bool ReadIntegerVectorSimple(const std::string &rxfilename, std::vector< int32 > *list)
ReadFromList attempts to read this list of integers, one per line, from the given file...
std::string GetOptArg(int param) const
void RemoveSomeInputSymbols(const std::vector< I > &to_remove, MutableFst< Arc > *fst)
RemoveSomeInputSymbols removes any symbol that appears in "to_remove", from the input side of the FST...