arpa-file-parser-test.cc
Go to the documentation of this file.
1 // lm/arpa-file-parser-test.cc
2 
3 // Copyright 2016 Smart Action Company LLC (kkm)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 
11 // http://www.apache.org/licenses/LICENSE-2.0
12 
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
25 #include <iomanip>
26 #include <iostream>
27 #include <string>
28 #include <sstream>
29 #include <vector>
30 
31 #include "base/kaldi-common.h"
32 #include "fst/fstlib.h"
33 #include "lm/arpa-file-parser.h"
34 
35 namespace kaldi {
36 namespace {
37 
38 const int kMaxOrder = 3;
39 
40 struct NGramTestData {
42  float logprob;
43  int32 words[kMaxOrder];
44  float backoff;
45 };
46 
47 std::ostream& operator<<(std::ostream &os, const NGramTestData &data) {
48  std::ios::fmtflags saved_state(os.flags());
49  os << std::fixed << std::setprecision(6);
50 
51  os << data.logprob << ' ';
52  for (int i = 0; i < kMaxOrder; ++i) os << data.words[i] << ' ';
53  os << data.backoff << " // Line " << data.line_number;
54 
55  os.flags(saved_state);
56  return os;
57 }
58 
// Non-owning view of a C array: remembers the address of its first element
// and its element count. Used to simplify passing expected results to
// TestableArpaFileParser::Validate. The array must outlive the view.
template <class T>
struct CountedArray {
  template <size_t N>
  CountedArray(T (&source)[N]) : array(&source[0]), count(N) { }
  const T *array;
  const size_t count;
};

// Builds a CountedArray over `source`, deducing both T and the length.
template <class T, size_t N>
inline CountedArray<T> MakeCountedArray(T (&source)[N]) {
  CountedArray<T> view(source);
  return view;
}
73 
74 class TestableArpaFileParser : public ArpaFileParser {
75  public:
76  TestableArpaFileParser(const ArpaParseOptions &options,
77  fst::SymbolTable *symbols)
78  : ArpaFileParser(options, symbols),
79  header_available_(false),
80  read_complete_(false),
81  last_order_(0) { }
82  void Validate(CountedArray<int32> counts, CountedArray<NGramTestData> ngrams);
83 
84  private:
85  // ArpaFileParser overrides.
86  virtual void HeaderAvailable();
87  virtual void ConsumeNGram(const NGram& ngram);
88  virtual void ReadComplete();
89 
93  std::vector<NGramTestData> ngrams_;
94 };
95 
96 void TestableArpaFileParser::HeaderAvailable() {
99  header_available_ = true;
100  KALDI_ASSERT(NgramCounts().size() <= kMaxOrder);
101 }
102 
103 void TestableArpaFileParser::ConsumeNGram(const NGram& ngram) {
106  KALDI_ASSERT(ngram.words.size() <= NgramCounts().size());
107  KALDI_ASSERT(ngram.words.size() >= last_order_);
108  last_order_ = ngram.words.size();
109 
110  NGramTestData entry = { 0 };
111  entry.line_number = LineNumber();
112  entry.logprob = ngram.logprob;
113  entry.backoff = ngram.backoff;
114  std::copy(ngram.words.begin(), ngram.words.end(), entry.words);
115  ngrams_.push_back(entry);
116 }
117 
118 void TestableArpaFileParser::ReadComplete() {
121  read_complete_ = true;
122 }
123 
124 bool CompareNgrams(const NGramTestData &actual,
125  NGramTestData expected) {
126  expected.logprob *= Log(10.0);
127  expected.backoff *= Log(10.0);
128  if (actual.line_number != expected.line_number
129  || !std::equal(actual.words, actual.words + kMaxOrder,
130  expected.words)
131  || !ApproxEqual(actual.logprob, expected.logprob)
132  || !ApproxEqual(actual.backoff, expected.backoff)) {
133  KALDI_WARN << "Actual n-gram [" << actual
134  << "] differs from expected [" << expected << "]";
135  return false;
136  }
137  return true;
138 }
139 
140 void TestableArpaFileParser::Validate(
141  CountedArray<int32> expect_counts,
142  CountedArray<NGramTestData> expect_ngrams) {
143  // This needs better disagnostics probably.
144  KALDI_ASSERT(NgramCounts().size() == expect_counts.count);
145  KALDI_ASSERT(std::equal(NgramCounts().begin(), NgramCounts().end(),
146  expect_counts.array));
147 
148  KALDI_ASSERT(ngrams_.size() == expect_ngrams.count);
149  // auto mpos = std::mismatch(ngrams_.begin(), ngrams_.end(),
150  // expect_ngrams.array, CompareNgrams);
151  // if (mpos.first != ngrams_.end())
152  // KALDI_ERR << "Maismatch at index " << mpos.first - ngrams_.begin();
153  // TODO: auto above requres C++11, and I cannot spell out the type!!!
154  KALDI_ASSERT(std::equal(ngrams_.begin(), ngrams_.end(),
155  expect_ngrams.array, CompareNgrams));
156 }
157 
158 // Read integer LM (no symbols) with log base conversion.
159 void ReadIntegerLmLogconvExpectSuccess() {
160  KALDI_LOG << "ReadIntegerLmLogconvExpectSuccess()";
161 
162  static std::string integer_lm = "\
163 \\data\\\n\
164 ngram 1=4\n\
165 ngram 2=2\n\
166 ngram 3=2\n\
167 \n\
168 \\1-grams:\n\
169 -5.2\t4\t-3.3\n\
170 -3.4\t5\n\
171 0\t1\t-2.5\n\
172 -4.3\t2\n\
173 \n\
174 \\2-grams:\n\
175 -1.4\t4 5\t-3.2\n\
176 -1.3\t1 4\t-4.2\n\
177 \n\
178 \\3-grams:\n\
179 -0.3\t1 4 5\n\
180 -0.2\t4 5 2\n\
181 \n\
182 \\end\\";
183 
184  int32 expect_counts[] = { 4, 2, 2 };
185  NGramTestData expect_ngrams[] = {
186  { 7, -5.2, { 4, 0, 0 }, -3.3 },
187  { 8, -3.4, { 5, 0, 0 }, 0.0 },
188  { 9, 0.0, { 1, 0, 0 }, -2.5 },
189  { 10, -4.3, { 2, 0, 0 }, 0.0 },
190 
191  { 13, -1.4, { 4, 5, 0 }, -3.2 },
192  { 14, -1.3, { 1, 4, 0 }, -4.2 },
193 
194  { 17, -0.3, { 1, 4, 5 }, 0.0 },
195  { 18, -0.2, { 4, 5, 2 }, 0.0 } };
196 
197  ArpaParseOptions options;
198  options.bos_symbol = 1;
199  options.eos_symbol = 2;
200 
201  TestableArpaFileParser parser(options, NULL);
202  std::istringstream stm(integer_lm, std::ios_base::in);
203  parser.Read(stm);
204  parser.Validate(MakeCountedArray(expect_counts),
205  MakeCountedArray(expect_ngrams));
206 }
207 
// ARPA LM shared by the symbolic tests. \xCE\xB2 is the UTF-8 encoding of the
// Greek letter beta, included to exercise UTF-8 handling. The text before the
// lone "\data\" line, including a "\1-grams:" line, is meant to be skipped by
// the parser. The trailing space after the second "\1-grams:" and the tab
// after "\2-grams:" are presumably there to check that section headers
// tolerate trailing whitespace.
static std::string symbolic_lm =
    "We also allow random text coming before the \\data\\\n"
    "section marker. Even this is ok:\n"
    "\n"
    "\\1-grams:\n"
    "\n"
    "and should be ignored before the \\data\\ marker\n"
    "is seen alone by itself on a line.\n"
    "\n"
    "\\data\\\n"
    "ngram 1=4\n"
    "ngram 2=2\n"
    "ngram 3=2\n"
    "\n"
    "\\1-grams: \n"
    "-5.2\ta\t-3.3\n"
    "-3.4\t\xCE\xB2\n"
    "0.0\t<s>\t-2.5\n"
    "-4.3\t</s>\n"
    "\n"
    "\\2-grams:\t\n"
    "-1.5\ta \xCE\xB2\t-3.2\n"
    "-1.3\t<s> a\t-4.2\n"
    "\n"
    "\\3-grams:\n"
    "-0.3\t<s> a \xCE\xB2\n"
    "-0.2\t<s> a </s>\n"
    "\\end\\";
237 
238 // Symbol table that is created with predefined test symbols, "a" but no "b".
239 class TestSymbolTable : public fst::SymbolTable {
240  public:
241  TestSymbolTable() {
242  AddSymbol("<eps>", 0);
243  AddSymbol("<s>", 1);
244  AddSymbol("</s>", 2);
245  AddSymbol("<unk>", 3);
246  AddSymbol("a", 4);
247  }
248 };
249 
250 // Full expected result shared between ReadSymbolicLmNoOovImpl and
251 // ReadSymbolicLmWithOovAddToSymbols().
252 NGramTestData expect_symbolic_full[] = {
253  { 15, -5.2, { 4, 0, 0 }, -3.3 },
254  { 16, -3.4, { 5, 0, 0 }, 0.0 },
255  { 17, 0.0, { 1, 0, 0 }, -2.5 },
256  { 18, -4.3, { 2, 0, 0 }, 0.0 },
257 
258  { 21, -1.5, { 4, 5, 0 }, -3.2 },
259  { 22, -1.3, { 1, 4, 0 }, -4.2 },
260 
261  { 25, -0.3, { 1, 4, 5 }, 0.0 },
262  { 26, -0.2, { 1, 4, 2 }, 0.0 } };
263 
264 // This is run with all possible oov setting and yields same result.
265 void ReadSymbolicLmNoOovImpl(ArpaParseOptions::OovHandling oov) {
266  int32 expect_counts[] = { 4, 2, 2 };
267  TestSymbolTable symbols;
268  symbols.AddSymbol("\xCE\xB2", 5);
269 
270  ArpaParseOptions options;
271  options.bos_symbol = 1;
272  options.eos_symbol = 2;
273  options.unk_symbol = 3;
274  options.oov_handling = oov;
275  TestableArpaFileParser parser(options, &symbols);
276  std::istringstream stm(symbolic_lm, std::ios_base::in);
277  parser.Read(stm);
278  parser.Validate(MakeCountedArray(expect_counts),
279  MakeCountedArray(expect_symbolic_full));
280  KALDI_ASSERT(symbols.NumSymbols() == 6);
281 }
282 
283 void ReadSymbolicLmNoOovTests() {
284  KALDI_LOG << "ReadSymbolicLmNoOovImpl(kRaiseError)";
285  ReadSymbolicLmNoOovImpl(ArpaParseOptions::kRaiseError);
286  KALDI_LOG << "ReadSymbolicLmNoOovImpl(kAddToSymbols)";
287  ReadSymbolicLmNoOovImpl(ArpaParseOptions::kAddToSymbols);
288  KALDI_LOG << "ReadSymbolicLmNoOovImpl(kReplaceWithUnk)";
289  ReadSymbolicLmNoOovImpl(ArpaParseOptions::kReplaceWithUnk);
290  KALDI_LOG << "ReadSymbolicLmNoOovImpl(kSkipNGram)";
291  ReadSymbolicLmNoOovImpl(ArpaParseOptions::kSkipNGram);
292 }
293 
294 // This is run with all possible oov setting and yields same result.
295 void ReadSymbolicLmWithOovImpl(
297  CountedArray<NGramTestData> expect_ngrams,
298  fst::SymbolTable* symbols) {
299  int32 expect_counts[] = { 4, 2, 2 };
300  ArpaParseOptions options;
301  options.bos_symbol = 1;
302  options.eos_symbol = 2;
303  options.unk_symbol = 3;
304  options.oov_handling = oov;
305  TestableArpaFileParser parser(options, symbols);
306  std::istringstream stm(symbolic_lm, std::ios_base::in);
307  parser.Read(stm);
308  parser.Validate(MakeCountedArray(expect_counts), expect_ngrams);
309 }
310 
311 void ReadSymbolicLmWithOovAddToSymbols() {
312  TestSymbolTable symbols;
313  ReadSymbolicLmWithOovImpl(ArpaParseOptions::kAddToSymbols,
314  MakeCountedArray(expect_symbolic_full),
315  &symbols);
316  KALDI_ASSERT(symbols.NumSymbols() == 6);
317  KALDI_ASSERT(symbols.Find("\xCE\xB2") == 5);
318 }
319 
320 void ReadSymbolicLmWithOovReplaceWithUnk() {
321  NGramTestData expect_symbolic_unk_b[] = {
322  { 15, -5.2, { 4, 0, 0 }, -3.3 },
323  { 16, -3.4, { 3, 0, 0 }, 0.0 },
324  { 17, 0.0, { 1, 0, 0 }, -2.5 },
325  { 18, -4.3, { 2, 0, 0 }, 0.0 },
326 
327  { 21, -1.5, { 4, 3, 0 }, -3.2 },
328  { 22, -1.3, { 1, 4, 0 }, -4.2 },
329 
330  { 25, -0.3, { 1, 4, 3 }, 0.0 },
331  { 26, -0.2, { 1, 4, 2 }, 0.0 } };
332 
333  TestSymbolTable symbols;
334  ReadSymbolicLmWithOovImpl(ArpaParseOptions::kReplaceWithUnk,
335  MakeCountedArray(expect_symbolic_unk_b),
336  &symbols);
337  KALDI_ASSERT(symbols.NumSymbols() == 5);
338 }
339 
340 void ReadSymbolicLmWithOovSkipNGram() {
341  NGramTestData expect_symbolic_no_b[] = {
342  { 15, -5.2, { 4, 0, 0 }, -3.3 },
343  { 17, 0.0, { 1, 0, 0 }, -2.5 },
344  { 18, -4.3, { 2, 0, 0 }, 0.0 },
345 
346  { 22, -1.3, { 1, 4, 0 }, -4.2 },
347 
348  { 26, -0.2, { 1, 4, 2 }, 0.0 } };
349 
350  TestSymbolTable symbols;
351  ReadSymbolicLmWithOovImpl(ArpaParseOptions::kSkipNGram,
352  MakeCountedArray(expect_symbolic_no_b),
353  &symbols);
354  KALDI_ASSERT(symbols.NumSymbols() == 5);
355 }
356 
357 void ReadSymbolicLmWithOovTests() {
358  KALDI_LOG << "ReadSymbolicLmWithOovAddToSymbols()";
359  ReadSymbolicLmWithOovAddToSymbols();
360  KALDI_LOG << "ReadSymbolicLmWithOovReplaceWithUnk()";
361  ReadSymbolicLmWithOovReplaceWithUnk();
362  KALDI_LOG << "ReadSymbolicLmWithOovSkipNGram()";
363  ReadSymbolicLmWithOovSkipNGram();
364 }
365 
366 } // namespace
367 } // namespace kaldi
368 
369 int main(int argc, char *argv[]) {
370  kaldi::ReadIntegerLmLogconvExpectSuccess();
371  kaldi::ReadSymbolicLmNoOovTests();
372  kaldi::ReadSymbolicLmWithOovTests();
373 }
int32 words[kMaxOrder]
std::ostream & operator<<(std::ostream &os, const MatrixBase< Real > &M)
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
std::vector< NGramTestData > ngrams_
float logprob
kaldi::int32 int32
Add novel words to the symbol table.
const size_t count
double Log(double x)
Definition: kaldi-math.h:100
float backoff
bool header_available_
#define KALDI_WARN
Definition: kaldi-error.h:150
Skip n-gram with OOV word and continue.
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
int main(int argc, char *argv[])
int32 last_order_
const T * array
bool read_complete_
int32 line_number
#define KALDI_LOG
Definition: kaldi-error.h:153
Replace OOV words with <unk>.
static bool ApproxEqual(float a, float b, float relative_tolerance=0.001)
return abs(a - b) <= relative_tolerance * (abs(a)+abs(b)).
Definition: kaldi-math.h:265