32 #include "fst/fstlib.h" 38 const int kMaxOrder = 3;
// Plain record describing one n-gram as seen by the tests. Code in this file
// reads and writes its `line_number`, `logprob`, `words` (indexed up to
// kMaxOrder) and `backoff` members.
// NOTE(review): the brace initializers used by the tests suggest the member
// order is { line_number, logprob, words, backoff } -- confirm against the
// member declarations.
40 struct NGramTestData {
// Streams an NGramTestData in a human-readable form: the log-probability,
// all kMaxOrder word slots (each followed by a space), the backoff weight,
// and finally the text " // Line " followed by the line number.
// Numbers are written in fixed notation with 6 digits after the point.
47 std::ostream&
operator<<(std::ostream &os,
const NGramTestData &data) {
// Remember the caller's format flags so they can be put back afterwards.
48 std::ios::fmtflags saved_state(os.flags());
49 os << std::fixed << std::setprecision(6);
51 os << data.logprob <<
' ';
52 for (
int i = 0;
i < kMaxOrder; ++
i) os << data.words[
i] <<
' ';
53 os << data.backoff <<
" // Line " << data.line_number;
// NOTE(review): flags() does not include the precision, so the precision of 6
// set above appears to stay in effect on `os` after this restore -- confirm
// whether that is intended.
55 os.flags(saved_state);
// Helper that wraps a reference to a built-in array of N elements of type T
// in a CountedArray<T>, letting call sites pass an array by name.
// NOTE(review): CountedArray is declared elsewhere; presumably its
// constructor records N as the element count -- confirm.
69 template <
class T,
size_t N>
70 inline CountedArray<T> MakeCountedArray(T(&
array)[N]) {
71 return CountedArray<T>(
array);
// ArpaFileParser subclass used by the tests in this file. It adds a
// Validate() method that checks what was parsed against expected values.
74 class TestableArpaFileParser :
public ArpaFileParser {
// Passes the parse options and the symbol table straight to the base class.
76 TestableArpaFileParser(
const ArpaParseOptions &options,
77 fst::SymbolTable *symbols)
78 : ArpaFileParser(options, symbols),
// Asserts that the parsed n-gram counts and n-grams match the expected ones.
82 void Validate(CountedArray<int32> counts, CountedArray<NGramTestData> ngrams);
// Virtual callbacks implemented by this class.
// NOTE(review): presumably these override hooks declared by ArpaFileParser --
// confirm against the base class.
86 virtual void HeaderAvailable();
87 virtual void ConsumeNGram(
const NGram& ngram);
88 virtual void ReadComplete();
// Parser callback implemented by the test parser.
// NOTE(review): presumably invoked once the ARPA header has been read --
// confirm against ArpaFileParser.
96 void TestableArpaFileParser::HeaderAvailable() {
// Parser callback receiving one n-gram. Converts it into an NGramTestData
// record carrying the current line number, log-probability, backoff weight
// and word ids.
103 void TestableArpaFileParser::ConsumeNGram(
const NGram& ngram) {
// The n-gram may not have more words than there are n-gram orders counted.
106 KALDI_ASSERT(ngram.words.size() <= NgramCounts().size());
// Start from a zero-initialized record so unused word slots stay 0.
110 NGramTestData entry = { 0 };
111 entry.line_number = LineNumber();
112 entry.logprob = ngram.logprob;
113 entry.backoff = ngram.backoff;
// NOTE(review): this copy relies on ngram.words fitting into entry.words;
// verify that the word count is also bounded by the size of that array.
114 std::copy(ngram.words.begin(), ngram.words.end(), entry.words);
// Parser callback implemented by the test parser.
// NOTE(review): presumably invoked after the whole ARPA input has been
// consumed -- confirm against ArpaFileParser.
118 void TestableArpaFileParser::ReadComplete() {
// Predicate comparing a parsed n-gram (`actual`) with its expected value.
// `expected` is taken by value because it is rescaled locally before the
// comparison. A mismatch is reported with a message containing both records.
124 bool CompareNgrams(
const NGramTestData &actual,
125 NGramTestData expected) {
// Scale the expected weights by Log(10.0).
// NOTE(review): presumably the expectations are written as base-10 logs and
// the parser yields natural logs -- confirm.
126 expected.logprob *=
Log(10.0);
127 expected.backoff *=
Log(10.0);
// Line numbers and word ids must match exactly; the backoff weight is
// compared with ApproxEqual.
128 if (actual.line_number != expected.line_number
129 || !std::equal(actual.words, actual.words + kMaxOrder,
132 || !
ApproxEqual(actual.backoff, expected.backoff)) {
134 <<
"] differs from expected [" << expected <<
"]";
// Checks the parser's results against the expectations supplied by a test.
//   expect_counts: expected number of n-grams for each order.
//   expect_ngrams: expected n-gram records, compared with CompareNgrams.
140 void TestableArpaFileParser::Validate(
141 CountedArray<int32> expect_counts,
142 CountedArray<NGramTestData> expect_ngrams) {
// The number of orders and each per-order count must match exactly.
144 KALDI_ASSERT(NgramCounts().size() == expect_counts.count);
145 KALDI_ASSERT(std::equal(NgramCounts().begin(), NgramCounts().end(),
146 expect_counts.array));
155 expect_ngrams.array, CompareNgrams));
// Test case: feeds an in-memory LM text (`integer_lm`) through a
// TestableArpaFileParser constructed without a symbol table and validates
// the result against the expected counts { 4, 2, 2 } and the n-gram records
// listed below.
// NOTE(review): each expected record is presumably
// { line_number, logprob, words, backoff } -- confirm against NGramTestData.
159 void ReadIntegerLmLogconvExpectSuccess() {
160 KALDI_LOG <<
"ReadIntegerLmLogconvExpectSuccess()";
162 static std::string integer_lm =
"\ 184 int32 expect_counts[] = { 4, 2, 2 };
185 NGramTestData expect_ngrams[] = {
186 { 7, -5.2, { 4, 0, 0 }, -3.3 },
187 { 8, -3.4, { 5, 0, 0 }, 0.0 },
188 { 9, 0.0, { 1, 0, 0 }, -2.5 },
189 { 10, -4.3, { 2, 0, 0 }, 0.0 },
191 { 13, -1.4, { 4, 5, 0 }, -3.2 },
192 { 14, -1.3, { 1, 4, 0 }, -4.2 },
194 { 17, -0.3, { 1, 4, 5 }, 0.0 },
195 { 18, -0.2, { 4, 5, 2 }, 0.0 } };
// Word ids 1 and 2 are configured as the BOS and EOS symbols.
197 ArpaParseOptions options;
198 options.bos_symbol = 1;
199 options.eos_symbol = 2;
// No symbol table is supplied (NULL).
// NOTE(review): presumably because the LM text already uses integer word
// ids -- confirm.
201 TestableArpaFileParser parser(options, NULL);
202 std::istringstream stm(integer_lm, std::ios_base::in);
204 parser.Validate(MakeCountedArray(expect_counts),
205 MakeCountedArray(expect_ngrams));
// `symbolic_lm`: in-memory LM text using word symbols. Its text includes free
// commentary placed before the \data\ marker and entries containing the
// UTF-8 bytes \xCE\xB2.
// `TestSymbolTable`: fst::SymbolTable subclass used by the symbolic tests; the
// AddSymbol calls below register "<eps>" as 0, "</s>" as 2 and "<unk>" as 3.
209 static std::string symbolic_lm =
"\ 210 We also allow random text coming before the \\data\\\n\ 211 section marker. Even this is ok:\n\ 215 and should be ignored before the \\data\\ marker\n\ 216 is seen alone by itself on a line.\n\ 230 -1.5\ta \xCE\xB2\t-3.2\n\ 234 -0.3\t<s> a \xCE\xB2\n\ 239 class TestSymbolTable :
public fst::SymbolTable {
242 AddSymbol(
"<eps>", 0);
244 AddSymbol(
"</s>", 2);
245 AddSymbol(
"<unk>", 3);
// Expected n-grams for `symbolic_lm` when every word has its own id; used by
// the tests that pass it to Validate().
// NOTE(review): each record is presumably
// { line_number, logprob, words, backoff } -- confirm against NGramTestData.
252 NGramTestData expect_symbolic_full[] = {
253 { 15, -5.2, { 4, 0, 0 }, -3.3 },
254 { 16, -3.4, { 5, 0, 0 }, 0.0 },
255 { 17, 0.0, { 1, 0, 0 }, -2.5 },
256 { 18, -4.3, { 2, 0, 0 }, 0.0 },
258 { 21, -1.5, { 4, 5, 0 }, -3.2 },
259 { 22, -1.3, { 1, 4, 0 }, -4.2 },
261 { 25, -0.3, { 1, 4, 5 }, 0.0 },
262 { 26, -0.2, { 1, 4, 2 }, 0.0 } };
// Body of the no-OOV symbolic test: parses `symbolic_lm` with a symbol table
// that already contains every word and validates against
// `expect_symbolic_full`. `oov` is the OOV-handling mode under test.
// NOTE(review): `oov` is presumably a parameter of the enclosing function --
// confirm against its signature.
266 int32 expect_counts[] = { 4, 2, 2 };
267 TestSymbolTable symbols;
// Register the symbol "\xCE\xB2" with id 5 on top of the base test table.
268 symbols.AddSymbol(
"\xCE\xB2", 5);
270 ArpaParseOptions options;
271 options.bos_symbol = 1;
272 options.eos_symbol = 2;
273 options.unk_symbol = 3;
274 options.oov_handling = oov;
275 TestableArpaFileParser parser(options, &symbols);
276 std::istringstream stm(symbolic_lm, std::ios_base::in);
278 parser.Validate(MakeCountedArray(expect_counts),
279 MakeCountedArray(expect_symbolic_full));
// Test group for the no-OOV symbolic case. Logs one label per OOV-handling
// mode: kRaiseError, kAddToSymbols, kReplaceWithUnk and kSkipNGram.
// NOTE(review): presumably each label is followed by a call to
// ReadSymbolicLmNoOovImpl with the named mode -- confirm.
283 void ReadSymbolicLmNoOovTests() {
284 KALDI_LOG <<
"ReadSymbolicLmNoOovImpl(kRaiseError)";
286 KALDI_LOG <<
"ReadSymbolicLmNoOovImpl(kAddToSymbols)";
288 KALDI_LOG <<
"ReadSymbolicLmNoOovImpl(kReplaceWithUnk)";
290 KALDI_LOG <<
"ReadSymbolicLmNoOovImpl(kSkipNGram)";
// Shared driver for the symbolic tests with an out-of-vocabulary word: parses
// `symbolic_lm` using the caller's symbol table and the OOV-handling mode
// `oov`, then validates against the counts { 4, 2, 2 } and `expect_ngrams`.
//   expect_ngrams: n-gram records the parser is expected to produce.
//   symbols:       symbol table handed to the parser.
295 void ReadSymbolicLmWithOovImpl(
297 CountedArray<NGramTestData> expect_ngrams,
298 fst::SymbolTable* symbols) {
299 int32 expect_counts[] = { 4, 2, 2 };
300 ArpaParseOptions options;
301 options.bos_symbol = 1;
302 options.eos_symbol = 2;
303 options.unk_symbol = 3;
304 options.oov_handling = oov;
305 TestableArpaFileParser parser(options, symbols);
306 std::istringstream stm(symbolic_lm, std::ios_base::in);
308 parser.Validate(MakeCountedArray(expect_counts), expect_ngrams);
// OOV test that expects the full n-gram list (`expect_symbolic_full`) when
// starting from a plain TestSymbolTable.
// NOTE(review): presumably runs ReadSymbolicLmWithOovImpl in the
// add-to-symbols mode -- confirm.
311 void ReadSymbolicLmWithOovAddToSymbols() {
312 TestSymbolTable symbols;
314 MakeCountedArray(expect_symbolic_full),
// OOV test whose expectations carry word id 3 (the value the tests assign to
// options.unk_symbol) in the positions where `expect_symbolic_full` has id 5.
// NOTE(review): presumably runs ReadSymbolicLmWithOovImpl in the
// replace-with-unk mode -- confirm.
320 void ReadSymbolicLmWithOovReplaceWithUnk() {
321 NGramTestData expect_symbolic_unk_b[] = {
322 { 15, -5.2, { 4, 0, 0 }, -3.3 },
323 { 16, -3.4, { 3, 0, 0 }, 0.0 },
324 { 17, 0.0, { 1, 0, 0 }, -2.5 },
325 { 18, -4.3, { 2, 0, 0 }, 0.0 },
327 { 21, -1.5, { 4, 3, 0 }, -3.2 },
328 { 22, -1.3, { 1, 4, 0 }, -4.2 },
330 { 25, -0.3, { 1, 4, 3 }, 0.0 },
331 { 26, -0.2, { 1, 4, 2 }, 0.0 } };
333 TestSymbolTable symbols;
335 MakeCountedArray(expect_symbolic_unk_b),
// OOV test whose expectations omit the records for lines 16, 21 and 25, i.e.
// the entries that contain word id 5 in `expect_symbolic_full`.
// NOTE(review): presumably runs ReadSymbolicLmWithOovImpl in the skip-n-gram
// mode -- confirm.
340 void ReadSymbolicLmWithOovSkipNGram() {
341 NGramTestData expect_symbolic_no_b[] = {
342 { 15, -5.2, { 4, 0, 0 }, -3.3 },
343 { 17, 0.0, { 1, 0, 0 }, -2.5 },
344 { 18, -4.3, { 2, 0, 0 }, 0.0 },
346 { 22, -1.3, { 1, 4, 0 }, -4.2 },
348 { 26, -0.2, { 1, 4, 2 }, 0.0 } };
350 TestSymbolTable symbols;
352 MakeCountedArray(expect_symbolic_no_b),
// Test group for the OOV case: logs the name of each of the three OOV tests
// and then runs it, in the order add-to-symbols, replace-with-unk,
// skip-n-gram.
357 void ReadSymbolicLmWithOovTests() {
358 KALDI_LOG <<
"ReadSymbolicLmWithOovAddToSymbols()";
359 ReadSymbolicLmWithOovAddToSymbols();
360 KALDI_LOG <<
"ReadSymbolicLmWithOovReplaceWithUnk()";
361 ReadSymbolicLmWithOovReplaceWithUnk();
362 KALDI_LOG <<
"ReadSymbolicLmWithOovSkipNGram()";
363 ReadSymbolicLmWithOovSkipNGram();
// Test entry point: runs the integer-LM test, then the symbolic no-OOV test
// group, then the symbolic OOV test group. The command-line arguments are
// not used by the calls shown here.
369 int main(
int argc,
char *argv[]) {
370 kaldi::ReadIntegerLmLogconvExpectSuccess();
371 kaldi::ReadSymbolicLmNoOovTests();
372 kaldi::ReadSymbolicLmWithOovTests();
std::ostream & operator<<(std::ostream &os, const MatrixBase< Real > &M)
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
std::vector< NGramTestData > ngrams_
Add novel words to the symbol table.
Skip n-gram with OOV word and continue.
#define KALDI_ASSERT(cond)
int main(int argc, char *argv[])
Replace OOV words with <unk>.
static bool ApproxEqual(float a, float b, float relative_tolerance=0.001)
Returns true if abs(a - b) <= relative_tolerance * (abs(a) + abs(b)).