lattice-weight-test.cc
Go to the documentation of this file.
1 // fstext/lattice-weight-test.cc
2 
3 // Copyright 2009-2011 Microsoft Corporation
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 
11 // http://www.apache.org/licenses/LICENSE-2.0
12 
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 #include "base/kaldi-math.h"
20 #include "fstext/lattice-weight.h"
21 
22 namespace fst {
// Pull std::vector and std::cout into namespace fst for the helpers and tests below.
23 using std::vector;
24 using std::cout;
25 // these typedefs are the same as in ../lat/kaldi-lattice.h, but
26 // just used here for testing (doesn't matter if they get out of
27 // sync).
// NOTE(review): only the BaseFloat typedef is visible in this listing; listing
// lines 30-35 are missing.  The code below uses LatticeWeight and
// CompactLatticeWeight, so presumably their typedefs belong here -- confirm
// against the real file.
28 typedef float BaseFloat;
29 
31 
33 
36 
37 
38 LatticeWeight RandomLatticeWeight() {
39  int tmp = kaldi::Rand() % 4;
40  if (tmp == 0) {
41  return LatticeWeight::Zero();
42  } else if (tmp == 1) {
43  return LatticeWeight( 1, 2); // sometimes return special values..
44  } else if (tmp == 2) {
45  return LatticeWeight( 2, 1); // this tests more thoroughly certain properties...
46  } else {
47  return LatticeWeight( 100 * kaldi::RandGauss(), 100 * kaldi::RandGauss());
48  }
49 }
50 
51 CompactLatticeWeight RandomCompactLatticeWeight() {
52  LatticeWeight w = RandomLatticeWeight();
53  if (w == LatticeWeight::Zero()) {
54  return CompactLatticeWeight(w, vector<int32>());
55  } else {
56  int32 len = kaldi::Rand() % 4;
57  vector<int32> str;
58  for(int32 i = 0; i < len; i++)
59  str.push_back(kaldi::Rand() % 10 + 1);
60  return CompactLatticeWeight(w, str);
61  }
62 }
63 
// NOTE(review): the line that opens this function (listing line 64) is missing
// from this extract.  The symbol index at the end of the listing names
// `void LatticeWeightTest()`, which is presumably the missing signature --
// confirm against the real file.
//
// Runs 100 trials on weights drawn from RandomLatticeWeight(), asserting
// identities that involve Plus, Times, Divide, Reverse, Quantize, NaturalLess,
// Compare and Member, and round-tripping a weight through
// operator<< / operator>> and through Write / Read.
65  for(int32 i = 0; i < 100; i++) {
66  LatticeWeight l1 = RandomLatticeWeight(), l2 = RandomLatticeWeight();
67  LatticeWeight l3 = Plus(l1, l2);
68  LatticeWeight l4 = Times(l1, l2);
// f1..f4 are Value1() + Value2() of l1..l4.  The two AssertEqual calls check
// that the sum for Plus(l1, l2) is min(f1, f2) and the sum for Times(l1, l2)
// is f1 + f2.
69  BaseFloat f1 = l1.Value1() + l1.Value2(), f2 = l2.Value1() + l2.Value2(), f3 = l3.Value1() + l3.Value2(),
70  f4 = l4.Value1() + l4.Value2();
71  kaldi::AssertEqual(std::min(f1, f2), f3);
72  kaldi::AssertEqual(f1 + f2, f4);
73 
// Plus(x, x) == x.
74  KALDI_ASSERT(Plus(l3, l3) == l3);
75  KALDI_ASSERT(Plus(l1, l2) == Plus(l2, l1)); // commutativity of plus
76  KALDI_ASSERT(Times(l1, l2) == Times(l2, l1)); // commutativity of Times (true for this semiring, not always)
77  KALDI_ASSERT(Plus(l3, LatticeWeight::Zero()) == l3); // x + 0 = x
78  KALDI_ASSERT(Times(l3, LatticeWeight::One()) == l3); // x * 1 = x
// NOTE(review): listing line 79 is missing from this extract.
80 
// Reversing twice must give back the original weight.
81  KALDI_ASSERT(l3.Reverse().Reverse() == l3);
82 
// NaturalLess(l1, l2) must be true exactly when Plus(l1, l2) == l1 and l1 != l2.
83  NaturalLess<LatticeWeight> nl;
84  bool a = nl(l1, l2);
85  bool b = (Plus(l1, l2) == l1 && l1 != l2);
86  KALDI_ASSERT(a == b);
87 
88  KALDI_ASSERT(Compare(l1, Plus(l1, l2)) != 1); // so do not have l1 > l1 + l2
89  LatticeWeight l5 = RandomLatticeWeight(), l6 = RandomLatticeWeight();
90  {
// Distributivity of Times over Plus.  A mismatch (per ApproxEqual) is only
// printed to std::cout here; the asserting form is left commented out below.
91  LatticeWeight wa = Times(Plus(l1, l2), Plus(l5, l6)),
92  wb = Plus(Times(l1, l5), Plus(Times(l1, l6),
93  Plus(Times(l2, l5), Times(l2, l6))));
94  if (!ApproxEqual(wa, wb)) {
95  std::cout << "l1 = " << l1 << ", l2 = " << l2
96  << ", l5 = " << l5 << ", l6 = " << l6 << "\n";
97  std::cout << "ERROR: " << wa << " != " << wb << "\n";
98  }
99  // KALDI_ASSERT(Times(Plus(l1, l2), Plus(l5, l6))
100  // == Plus(Times(l1, l5), Plus(Times(l1,l6),
101  // Plus(Times(l2, l5), Times(l2, l6))))); // * distributes over +
102  }
103  KALDI_ASSERT(l1.Member() && l2.Member() && l3.Member() && l4.Member()
104  && l5.Member() && l6.Member());
105  if (l2 != LatticeWeight::Zero())
106  KALDI_ASSERT(ApproxEqual(Divide(Times(l1, l2), l2), l1)); // (a*b) / b = a if b != 0
107  KALDI_ASSERT(ApproxEqual(l1, l1.Quantize()));
108 
// Stream round-trip: print l1, parse the text back into l2 (overwriting it),
// and require approximate equality with tolerance 0.001.
109  std::ostringstream s1;
110  s1 << l1;
111  std::istringstream s2(s1.str());
112  s2 >> l2;
113  KALDI_ASSERT(ApproxEqual(l1, l2, 0.001));
114  std::cout << s1.str() << '\n';
115  {
// Write/Read round-trip: l3 is overwritten by Read and must equal l1 exactly.
116  std::ostringstream s1b;
117  l1.Write(s1b);
118  std::istringstream s2b(s1b.str());
119  l3.Read(s2b);
120  KALDI_ASSERT(l1 == l3);
121  }
122  }
123 }
124 
125 
// NOTE(review): the line that opens this function (listing line 126) is
// missing from this extract.  The symbol index at the end of the listing names
// `void CompactLatticeWeightTest()`, which is presumably the missing
// signature -- confirm against the real file.
//
// Runs 100 trials on weights drawn from RandomCompactLatticeWeight(), asserting
// identities that involve Plus, Times, Divide (both DIVIDE_RIGHT and
// DIVIDE_LEFT), Quantize, NaturalLess, Compare and Member, round-tripping a
// weight through operator<< / operator>> and Write / Read, and then exercising
// a `divisor` functor on two of the random weights.
127  for(int32 i = 0; i < 100; i++) {
128  CompactLatticeWeight l1 = RandomCompactLatticeWeight(), l2 = RandomCompactLatticeWeight();
129  CompactLatticeWeight l3 = Plus(l1, l2);
130  CompactLatticeWeight l4 = Times(l1, l2);
131 
// Plus(x, x) == x.
132  KALDI_ASSERT(Plus(l3, l3) == l3);
133  KALDI_ASSERT(Plus(l1, l2) == Plus(l2, l1)); // commutativity of plus
134  KALDI_ASSERT(Plus(l3, CompactLatticeWeight::Zero()) == l3); // x + 0 = x
135  KALDI_ASSERT(Times(l3, CompactLatticeWeight::One()) == l3); // x * 1 = x
// NOTE(review): listing line 136 is missing from this extract.
// NaturalLess(l1, l2) must be true exactly when Plus(l1, l2) == l1 and l1 != l2.
137  NaturalLess<CompactLatticeWeight> nl;
138  bool a = nl(l1, l2);
139  bool b = (Plus(l1, l2) == l1 && l1 != l2);
140  KALDI_ASSERT(a == b);
141 
142  KALDI_ASSERT(Compare(l1, Plus(l1, l2)) != 1); // so do not have l1 > l1 + l2
143  CompactLatticeWeight l5 = RandomCompactLatticeWeight(), l6 = RandomCompactLatticeWeight();
// Here distributivity is asserted with exact operator== (no ApproxEqual).
144  KALDI_ASSERT(Times(Plus(l1, l2), Plus(l5, l6)) ==
145  Plus(Times(l1, l5), Plus(Times(l1, l6),
146  Plus(Times(l2, l5), Times(l2, l6))))); // * distributes over +
147  KALDI_ASSERT(l1.Member() && l2.Member() && l3.Member() && l4.Member()
148  && l5.Member() && l6.Member());
// Division is checked on both sides: l2 is removed from the right of l1*l2
// and from the left of l2*l1.
149  if (l2 != CompactLatticeWeight::Zero()) {
150  KALDI_ASSERT(ApproxEqual(Divide(Times(l1, l2), l2, DIVIDE_RIGHT), l1)); // (a*b) / b = a if b != 0
151  KALDI_ASSERT(ApproxEqual(Divide(Times(l2, l1), l2, DIVIDE_LEFT), l1)); // (a*b) / b = a if b != 0
152  }
153  KALDI_ASSERT(ApproxEqual(l1, l1.Quantize()));
154 
// Stream round-trip: print l1, parse the text back into l2 (overwriting it),
// and require ApproxEqual with its default tolerance.
155  std::ostringstream s1;
156  s1 << l1;
157  std::istringstream s2(s1.str());
158  s2 >> l2;
159  KALDI_ASSERT(ApproxEqual(l1, l2));
160  std::cout << s1.str() << '\n';
161 
162  {
// Write/Read round-trip: l3 is overwritten by Read and must equal l1 exactly.
163  std::ostringstream s1b;
164  l1.Write(s1b);
165  std::istringstream s2b(s1b.str());
166  l3.Read(s2b);
167  KALDI_ASSERT(l1 == l3);
168  }
169 
// NOTE(review): listing line 170 is missing from this extract, and `divisor`
// (used below) is not declared in any visible line.  The symbol index lists a
// CompactLatticeWeightCommonDivisor typedef, so presumably the missing line
// declares `divisor` with that type -- confirm against the real file.
// From here on l1..l4 are reused as scratch variables.
171  std::cout << "l5 = " << l5 << '\n';
172  std::cout << "l6 = " << l6 << '\n';
173  l1 = divisor(l5, l6);
174  std::cout << "div = " << l1 << '\n';
175  if (l1 != CompactLatticeWeight::Zero()) {
// Left-divide both inputs by the result of divisor(l5, l6), then apply
// divisor again to the two quotients.
176  l2 = Divide(l5, l1, DIVIDE_LEFT);
177  l3 = Divide(l6, l1, DIVIDE_LEFT);
178  std::cout << "l2 = " << l2 << '\n';
179  std::cout << "l3 = " << l3 << '\n';
180  l4 = divisor(l2, l3); // make sure l2 is now one.
181  std::cout << "l4 = " << l4 << '\n';
// NOTE(review): listing line 182 is missing from this extract (possibly the
// check on l4 that the comment two lines above refers to -- confirm).
183  } else {
// NOTE(review): listing line 184 is missing from this extract; the next line
// is only the tail of a statement whose beginning is not visible, so this
// branch does not parse as shown.
185  && l6 == CompactLatticeWeight::Zero());
186  }
187  }
188 }
189 
190 
191 }
192 
// Program entry point.
// NOTE(review): listing lines 194-195 (the body) are missing from this
// extract, so what main() calls cannot be seen here.  Presumably it invokes
// the test functions defined in namespace fst above -- confirm against the
// real file.
193 int main() {
196 }
197 
LatticeWeightTpl Quantize(float delta=kDelta) const
CompactLatticeWeight RandomCompactLatticeWeight()
LatticeWeightTpl< FloatType > Divide(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2, DivideType typ=DIVIDE_ANY)
LatticeWeightTpl< FloatType > Reverse() const
static const LatticeWeightTpl One()
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
std::istream & Read(std::istream &strm)
LatticeWeightTpl< FloatType > Plus(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
float RandGauss(struct RandomState *state=NULL)
Definition: kaldi-math.h:155
kaldi::int32 int32
bool ApproxEqual(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2, float delta=kDelta)
int main()
std::istream & Read(std::istream &strm)
void LatticeWeightTest()
LatticeWeightTpl< FloatType > Times(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
static const CompactLatticeWeightTpl< WeightType, IntType > One()
static const LatticeWeightTpl Zero()
int Rand(struct RandomState *state)
Definition: kaldi-math.cc:45
float BaseFloat
int Compare(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
Compare returns -1 if w1 < w2, +1 if w1 > w2, and 0 if w1 == w2.
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
LatticeWeight RandomLatticeWeight()
static const CompactLatticeWeightTpl< WeightType, IntType > Zero()
static void AssertEqual(float a, float b, float relative_tolerance=0.001)
assert abs(a - b) <= relative_tolerance * (abs(a)+abs(b))
Definition: kaldi-math.h:276
CompactLatticeWeightTpl Quantize(float delta=kDelta) const
LatticeWeightTpl< BaseFloat > LatticeWeight
CompactLatticeWeightTpl< LatticeWeight, int32 > CompactLatticeWeight
CompactLatticeWeightCommonDivisorTpl< LatticeWeight, int32 > CompactLatticeWeightCommonDivisor
void CompactLatticeWeightTest()