arpa-lm-compiler-test.cc
Go to the documentation of this file.
1 // lm/arpa-lm-compiler-test.cc
2 
3 // Copyright 2009-2011 Gilles Boulianne
4 // Copyright 2016 Smart Action LLC (kkm)
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include "base/kaldi-error.h"
#include "base/kaldi-math.h"
#include "lm/arpa-lm-compiler.h"
#include "util/kaldi-io.h"
29 
30 namespace kaldi {
31 
// Predefine some symbol values, because any integer is as good as any other.
// kEps is 0, the label OpenFst reserves for epsilon; the remaining symbols
// only need to be distinct from each other and from kEps, since the tests
// compare labels against them by value.
enum {
  kEps = 0,
  kDisambig,
  kBos,
  kEos,
};
38 
// How many random sentences CoverageTest() draws and composes with the LM.
constexpr int kRandomSentences = 50;
41 
42 // Creates an FST that generates any sequence of symbols taken from given
43 // symbol table. The FST is then associated with the symbol table.
44 static fst::StdVectorFst* CreateGenFst(bool seps, const fst::SymbolTable* pst) {
46  genFst->SetInputSymbols(pst);
47  genFst->SetOutputSymbols(pst);
48 
49  fst::StdArc::StateId midId = genFst->AddState();
50  if (!seps) {
51  fst::StdArc::StateId initId = genFst->AddState();
52  fst::StdArc::StateId finalId = genFst->AddState();
53  genFst->SetStart(initId);
54  genFst->SetFinal(finalId, fst::StdArc::Weight::One());
55  genFst->AddArc(initId, fst::StdArc(kBos, kBos, 0, midId));
56  genFst->AddArc(midId, fst::StdArc(kEos, kEos, 0, finalId));
57  } else {
58  genFst->SetStart(midId);
59  genFst->SetFinal(midId, fst::StdArc::Weight::One());
60  }
61 
62  // Add a loop for each symbol in the table except the four special ones.
63  fst::SymbolTableIterator si(*pst);
64  for (si.Reset(); !si.Done(); si.Next()) {
65  if (si.Value() == kBos || si.Value() == kEos ||
66  si.Value() == kEps || si.Value() == kDisambig)
67  continue;
68  genFst->AddArc(midId, fst::StdArc(si.Value(), si.Value(),
69  fst::StdArc::Weight::One(), midId));
70  }
71  return genFst;
72 }
73 
74 // Compile given ARPA file.
75 ArpaLmCompiler* Compile(bool seps, const std::string &infile) {
76  ArpaParseOptions options;
77  fst::SymbolTable symbols;
78  // Use spaces on special symbols, so we rather fail than read them by mistake.
79  symbols.AddSymbol(" <eps>", kEps);
80  symbols.AddSymbol(" #0", kDisambig);
81  options.bos_symbol = symbols.AddSymbol("<s>", kBos);
82  options.eos_symbol = symbols.AddSymbol("</s>", kEos);
84 
85  // Tests in this form cannot be run with epsilon substitution, unless every
86  // random path is also fitted with a #0-transducing self-loop.
87  ArpaLmCompiler* lm_compiler =
88  new ArpaLmCompiler(options,
89  seps ? kDisambig : 0,
90  &symbols);
91  {
92  Input ki(infile);
93  lm_compiler->Read(ki.Stream());
94  }
95  return lm_compiler;
96 }
97 
98 // Add a state to an FSA after last_state, add a form last_state to the new
99 // state, and return the new state.
101  fst::StdArc::StateId last_state,
102  int64 symbol) {
103  fst::StdArc::StateId next_state = fst->AddState();
104  fst->AddArc(last_state, fst::StdArc(symbol, symbol, 0, next_state));
105  return next_state;
106 }
107 
108 // Add a disambiguator-generating self loop to every state of an FST.
109 void AddSelfLoops(fst::StdMutableFst* fst) {
110  for (fst::StateIterator<fst::StdMutableFst> siter(*fst);
111  !siter.Done(); siter.Next()) {
112  fst->AddArc(siter.Value(),
113  fst::StdArc(kEps, kDisambig, 0, siter.Value()));
114  }
115 }
116 
117 // Compiles infile and then runs kRandomSentences random coverage tests on the
118 // compiled FST.
119 bool CoverageTest(bool seps, const std::string &infile) {
120  // Compile ARPA model.
121  ArpaLmCompiler* lm_compiler = Compile(seps, infile);
122 
123  // Create an FST that generates any sequence of symbols taken from the model
124  // output.
125  fst::StdVectorFst* genFst =
126  CreateGenFst(seps, lm_compiler->Fst().OutputSymbols());
127 
128  int num_successes = 0;
129  for (int32 i = 0; i < kRandomSentences; ++i) {
130  // Generate a random sentence FST.
131  fst::StdVectorFst sentence;
132  RandGen(*genFst, &sentence);
133  if (seps)
134  AddSelfLoops(&sentence);
135 
136  fst::ArcSort(lm_compiler->MutableFst(), fst::StdOLabelCompare());
137 
138  // The past must successfully compose with the LM FST.
139  fst::StdVectorFst composition;
140  Compose(sentence, lm_compiler->Fst(), &composition);
141  if (composition.Start() != fst::kNoStateId)
142  ++num_successes;
143  }
144 
145  delete genFst;
146  delete lm_compiler;
147 
148  bool ok = num_successes == kRandomSentences;
149  if (!ok) {
150  KALDI_WARN << "Coverage test failed on " << infile << ": composed "
151  << num_successes << "/" << kRandomSentences;
152  }
153  return ok;
154 }
155 
156 bool ScoringTest(bool seps, const std::string &infile, const std::string& sentence,
157  float expected) {
158  ArpaLmCompiler* lm_compiler = Compile(seps, infile);
159  const fst::SymbolTable* symbols = lm_compiler->Fst().InputSymbols();
160 
161  // Create a sentence FST for scoring.
162  fst::StdVectorFst sentFst;
163  fst::StdArc::StateId state = sentFst.AddState();
164  sentFst.SetStart(state);
165  if (!seps) {
166  state = AddToChainFsa(&sentFst, state, kBos);
167  }
168  std::stringstream ss(sentence);
169  std::string word;
170  while (ss >> word) {
171  int64 word_sym = symbols->Find(word);
172  KALDI_ASSERT(word_sym != -1);
173  state = AddToChainFsa(&sentFst, state, word_sym);
174  }
175  if (!seps) {
176  state = AddToChainFsa(&sentFst, state, kEos);
177  }
178  if (seps) {
179  AddSelfLoops(&sentFst);
180  }
181  sentFst.SetFinal(state, 0);
182  sentFst.SetOutputSymbols(symbols);
183 
184  // Do the composition and extract final weight.
185  fst::StdVectorFst composed;
186  fst::Compose(sentFst, lm_compiler->Fst(), &composed);
187  delete lm_compiler;
188 
189  if (composed.Start() == fst::kNoStateId) {
190  KALDI_WARN << "Test sentence " << sentence << " did not compose "
191  << "with the language model FST\n";
192  return false;
193  }
194 
195  std::vector<fst::StdArc::Weight> shortest;
196  fst::ShortestDistance(composed, &shortest, true);
197  float actual = shortest[composed.Start()].Value();
198 
199  bool ok = ApproxEqual(expected, actual);
200  if (!ok) {
201  KALDI_WARN << "Scored " << sentence << " in " << infile
202  << ": Expected=" << expected << " actual=" << actual;
203  }
204  return ok;
205 }
206 
207 bool ThrowsExceptionTest(bool seps, const std::string &infile) {
208  try {
209  // Make memory cleanup easy in both cases of try-catch block.
210  std::unique_ptr<ArpaLmCompiler> compiler(Compile(seps, infile));
211  return false;
212  } catch (const KaldiFatalError&) {
213  return true;
214  }
215 }
216 
217 } // namespace kaldi
218 
219 bool RunAllTests(bool seps) {
220  bool ok = true;
221  ok &= kaldi::CoverageTest(seps, "test_data/missing_backoffs.arpa");
222  ok &= kaldi::CoverageTest(seps, "test_data/unused_backoffs.arpa");
223  ok &= kaldi::CoverageTest(seps, "test_data/input.arpa");
224 
225  ok &= kaldi::ScoringTest(seps, "test_data/input.arpa", "b b b a", 59.2649);
226  ok &= kaldi::ScoringTest(seps, "test_data/input.arpa", "a b", 4.36082);
227 
228  ok &= kaldi::ThrowsExceptionTest(seps, "test_data/missing_bos.arpa");
229 
230  if (!ok) {
231  KALDI_WARN << "Tests " << (seps ? "with" : "without")
232  << " epsilon substitution FAILED";
233  }
234  return ok;
235 }
236 
237 int main(int argc, char *argv[]) {
238  bool ok = true;
239 
240  ok &= RunAllTests(false); // Without disambiguators (old behavior).
241  ok &= RunAllTests(true); // With epsilon substitution (new behavior).
242 
243  if (ok) {
244  KALDI_LOG << "All tests passed";
245  return 0;
246  } else {
247  KALDI_WARN << "Test FAILED";
248  return 1;
249  }
250 }
fst::StdArc::StateId StateId
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
ArpaLmCompiler * Compile(bool seps, const std::string &infile)
fst::StdArc::StateId AddToChainFsa(fst::StdMutableFst *fst, fst::StdArc::StateId last_state, int64 symbol)
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
fst::StdArc StdArc
Options that control ArpaFileParser.
void AddSelfLoops(const TransitionModel &trans_model, const std::vector< int32 > &disambig_syms, BaseFloat self_loop_scale, bool reorder, bool check_no_self_loops, fst::VectorFst< fst::StdArc > *fst)
For context, see AddSelfLoops().
Definition: hmm-utils.cc:602
const fst::StdVectorFst & Fst() const
bool ThrowsExceptionTest(bool seps, const std::string &infile)
bool RunAllTests(bool seps)
static fst::StdVectorFst * CreateGenFst(bool seps, const fst::SymbolTable *pst)
kaldi::int32 int32
fst::StdVectorFst * MutableFst()
fst::StdVectorFst StdVectorFst
Add novel words to the symbol table.
Kaldi fatal runtime error exception.
Definition: kaldi-error.h:89
int32 eos_symbol
Symbol for </s>, Required non-epsilon.
std::istream & Stream()
Definition: kaldi-io.cc:826
bool CoverageTest(bool seps, const std::string &infile)
#define KALDI_WARN
Definition: kaldi-error.h:150
void Read(std::istream &is)
Read ARPA LM file from a stream.
int32 bos_symbol
Symbol for <s>, Required non-epsilon.
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
int main(int argc, char *argv[])
bool ScoringTest(bool seps, const std::string &infile, const std::string &sentence, float expected)
OovHandling oov_handling
How to handle OOV words in the file.
static const int kRandomSentences
#define KALDI_LOG
Definition: kaldi-error.h:153
static bool ApproxEqual(float a, float b, float relative_tolerance=0.001)
return abs(a - b) <= relative_tolerance * (abs(a)+abs(b)).
Definition: kaldi-math.h:265