nnet-descriptor-test.cc
Go to the documentation of this file.
1 // nnet3/nnet-descriptor-test.cc
2 
3 // Copyright 2015 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include "nnet3/nnet-descriptor.h"
21 
22 namespace kaldi {
23 namespace nnet3 {
24 
26  if (Rand() % 2 != 0) {
27  return new SimpleForwardingDescriptor(Rand() % num_nodes);
28  } else {
29  int32 r = Rand() % 4;
30  if (r == 0) {
31  Index offset;
32  offset.t = Rand() % 5;
33  offset.x = Rand() % 2;
34  return
36  offset);
37  } else if (r == 1) {
38  std::vector<ForwardingDescriptor*> vec;
39  int32 n = 1 + Rand() % 3;
40  for (int32 i = 0; i < n; i++)
41  vec.push_back(GenRandForwardingDescriptor(num_nodes));
42  return new SwitchingForwardingDescriptor(vec);
43  } else if (r == 2) {
45  GenRandForwardingDescriptor(num_nodes), 1 + Rand() % 4);
46  } else {
48  GenRandForwardingDescriptor(num_nodes),
51  -2 + Rand() % 4);
52  }
53  }
54 }
55 
56 // generates a random SumDescriptor.
58  int32 num_nodes) {
59  if (Rand() % 3 != 0) {
60  bool not_required = (Rand() % 5 == 0);
61  if (not_required)
62  return new OptionalSumDescriptor(GenRandSumDescriptor(num_nodes));
63  else
64  return new SimpleSumDescriptor(GenRandForwardingDescriptor(num_nodes));
65  } else {
66  return new BinarySumDescriptor(
69  GenRandSumDescriptor(num_nodes),
70  GenRandSumDescriptor(num_nodes));
71  }
72 }
73 
74 
75 // generates a random descriptor.
76 void GenRandDescriptor(int32 num_nodes,
77  Descriptor *desc) {
78  int32 num_parts = 1 + Rand() % 3;
79  std::vector<SumDescriptor*> parts;
80  for (int32 part = 0; part < num_parts; part++)
81  parts.push_back(GenRandSumDescriptor(num_nodes));
82  *desc = Descriptor(parts);
83 
84 }
85 
86 
87 // This function tests both the I/O for the descriptors, and the
88 // Copy() function.
90  for (int32 i = 0; i < 100; i++) {
91  int32 num_nodes = 1 + Rand() % 5;
92  std::vector<std::string> node_names(num_nodes);
93  for (int32 i = 0; i < node_names.size(); i++) {
94  std::ostringstream ostr;
95  ostr << "a" << (i+1);
96  node_names[i] = ostr.str();
97  }
98  Descriptor desc;
99  std::ostringstream ostr;
100  GenRandDescriptor(num_nodes, &desc);
101  desc.WriteConfig(ostr, node_names);
102 
103  Descriptor desc2(desc), desc3, desc4;
104  desc3 = desc;
105  std::vector<std::string> tokens;
106  DescriptorTokenize(ostr.str(), &tokens);
107  tokens.push_back("end of input");
108  std::istringstream istr(ostr.str());
109  const std::string *next_token = &(tokens[0]);
110  bool ans = desc4.Parse(node_names, &next_token);
111  KALDI_ASSERT(ans);
112 
113  std::ostringstream ostr2;
114  desc2.WriteConfig(ostr2, node_names);
115  std::ostringstream ostr3;
116  desc3.WriteConfig(ostr3, node_names);
117  std::ostringstream ostr4;
118  desc4.WriteConfig(ostr4, node_names);
119 
120  KALDI_ASSERT(ostr.str() == ostr2.str());
121  KALDI_ASSERT(ostr.str() == ostr3.str());
122  KALDI_LOG << "x = " << ostr.str();
123  KALDI_LOG << "y = " << ostr4.str();
124  if (ostr.str() != ostr4.str()) {
125  KALDI_WARN << "x and y differ: checking that it's due to Offset normalization.";
126  KALDI_ASSERT(ostr.str().find("Offset(Offset") != std::string::npos ||
127  (ostr.str().find("Offset(") != std::string::npos &&
128  ostr.str().find(", 0)") != std::string::npos));
129  }
130  }
131 }
132 
133 
134 // This function tests GeneralDescriptor, but only for correctly-normalized input.
136  for (int32 i = 0; i < 100; i++) {
137  int32 num_nodes = 1 + Rand() % 5;
138  std::vector<std::string> node_names(num_nodes);
139  for (int32 i = 0; i < node_names.size(); i++) {
140  std::ostringstream ostr;
141  ostr << "a" << (i+1);
142  node_names[i] = ostr.str();
143  }
144  Descriptor desc;
145  std::ostringstream ostr;
146  GenRandDescriptor(num_nodes, &desc);
147  desc.WriteConfig(ostr, node_names);
148 
149  Descriptor desc2(desc), desc3;
150  desc3 = desc;
151  std::vector<std::string> tokens;
152  DescriptorTokenize(ostr.str(), &tokens);
153  tokens.push_back("end of input");
154  std::istringstream istr(ostr.str());
155  const std::string *next_token = &(tokens[0]);
156 
157 
158  GeneralDescriptor *gen_desc = GeneralDescriptor::Parse(node_names,
159  &next_token);
160 
161  if (*next_token != "end of input")
162  KALDI_ERR << "Parsing Descriptor, expected end of input but got "
163  << "'" << *next_token << "'";
164 
165  Descriptor *desc4 = gen_desc->ConvertToDescriptor();
166  std::ostringstream ostr2;
167  desc4->WriteConfig(ostr2, node_names);
168  KALDI_LOG << "Original descriptor was: " << ostr.str();
169  KALDI_LOG << "Parsed descriptor was: " << ostr2.str();
170  if (ostr2.str() != ostr.str())
171  KALDI_WARN << "Strings differed. Check manually.";
172 
173  delete gen_desc;
174  delete desc4;
175  }
176 }
177 
178 
179 // normalizes the text form of a descriptor.
180 std::string NormalizeTextDescriptor(const std::vector<std::string> &node_names,
181  const std::string &desc_str) {
182  std::vector<std::string> tokens;
183  DescriptorTokenize(desc_str, &tokens);
184  tokens.push_back("end of input");
185  const std::string *next_token = &(tokens[0]);
186  GeneralDescriptor *gen_desc = GeneralDescriptor::Parse(node_names,
187  &next_token);
188  if (*next_token != "end of input")
189  KALDI_ERR << "Parsing Descriptor, expected end of input but got "
190  << "'" << *next_token << "'";
191  Descriptor *desc = gen_desc->ConvertToDescriptor();
192  std::ostringstream ostr;
193  desc->WriteConfig(ostr, node_names);
194  delete gen_desc;
195  delete desc;
196  KALDI_LOG << "Result of normalizing " << desc_str << " is: " << ostr.str();
197  return ostr.str();
198 }
199 
201  std::vector<std::string> names;
202  names.push_back("a");
203  names.push_back("b");
204  names.push_back("c");
205  names.push_back("d");
206  KALDI_ASSERT(NormalizeTextDescriptor(names, "a") == "a");
207  KALDI_ASSERT(NormalizeTextDescriptor(names, "Scale(-1.0, a)") == "Scale(-1, a)");
208  KALDI_ASSERT(NormalizeTextDescriptor(names, "Scale(-1.0, Scale(-2.0, a))") == "Scale(2, a)");
209  KALDI_ASSERT(NormalizeTextDescriptor(names, "Scale(2.0, Sum(Scale(2.0, a), b, c))") ==
210  "Sum(Scale(4, a), Sum(Scale(2, b), Scale(2, c)))");
211  KALDI_ASSERT(NormalizeTextDescriptor(names, "Const(1.0, 512)") == "Const(1, 512)");
212  KALDI_ASSERT(NormalizeTextDescriptor(names, "Sum(Const(1.0, 512), Scale(-1.0, a))") ==
213  "Sum(Const(1, 512), Scale(-1, a))");
214  KALDI_ASSERT(NormalizeTextDescriptor(names, "Offset(Offset(a, 3, 5), 2, 1)")
215  == "Offset(a, 5, 6)");
216 
217  KALDI_ASSERT(NormalizeTextDescriptor(names, "Offset(Sum(a, b), 2, 1)") ==
218  "Sum(Offset(a, 2, 1), Offset(b, 2, 1))");
219  KALDI_ASSERT(NormalizeTextDescriptor(names, "Sum(Append(a, b), Append(c, d))") ==
220  "Append(Sum(a, c), Sum(b, d))");
221  KALDI_ASSERT(NormalizeTextDescriptor(names, "Append(Append(a, b), Append(c, d))") ==
222  "Append(a, b, c, d)");
223  KALDI_ASSERT(NormalizeTextDescriptor(names, "Sum(a, b, c, d)") ==
224  "Sum(a, Sum(b, Sum(c, d)))");
225  KALDI_ASSERT(NormalizeTextDescriptor(names, "Sum(a)") == "a");
226  KALDI_ASSERT(NormalizeTextDescriptor(names, "Offset(a, 0)") == "a");
227 }
228 
229 } // namespace nnet3
230 } // namespace kaldi
231 
// Entry point of the descriptor test program.  As visible in this extract it
// brings the kaldi and kaldi::nnet3 namespaces into scope, logs a success
// message and returns 0.
// NOTE(review): the embedded listing numbers jump from 236 to 240, so listing
// lines 237-239 are absent from this extract; presumably they invoke the unit
// tests defined above -- confirm against the original file.
232 int main() {
233  using namespace kaldi;
234  using namespace kaldi::nnet3;
235 
236 
240 
241 
242  KALDI_LOG << "Nnet descriptor tests succeeded.";
243 
244  return 0;
245 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
int main()
This is the case of class SumDescriptor, in which we contain just one term, and that term is optional...
void UnitTestGeneralDescriptorSpecial()
bool Parse(const std::vector< std::string > &node_names, const std::string **next_token)
This class is only used when parsing Descriptors.
SimpleForwardingDescriptor is the base-case of ForwardingDescriptor, consisting of a source node in t...
bool DescriptorTokenize(const std::string &input, std::vector< std::string > *tokens)
This function tokenizes input when parsing Descriptor configuration values.
Definition: nnet-parse.cc:30
BinarySumDescriptor can represent either A + B, or (A if defined, else B).
kaldi::int32 int32
ForwardingDescriptor * GenRandForwardingDescriptor(int32 num_nodes)
For use in clockwork RNNs and the like, this forwarding-descriptor rounds the time-index t down to th...
static GeneralDescriptor * Parse(const std::vector< std::string > &node_names, const std::string **next_token)
struct Index is intended to represent the various indexes by which we number the rows of the matrices...
Definition: nnet-common.h:44
This is an abstract base-class.
Chooses from different inputs based on the time index modulo (the number of ForwardingDescriptors...
struct rnnlm::@11::@12 n
#define KALDI_ERR
Definition: kaldi-error.h:147
This is the normal base-case of SumDescriptor which just wraps a ForwardingDescriptor.
#define KALDI_WARN
Definition: kaldi-error.h:150
A ForwardingDescriptor describes how we copy data from another NetworkNode, or from multiple other Ne...
int Rand(struct RandomState *state)
Definition: kaldi-math.cc:45
void GenRandDescriptor(int32 num_nodes, Descriptor *desc)
std::string NormalizeTextDescriptor(const std::vector< std::string > &node_names, const std::string &desc_str)
Offsets in 't' and 'x' values of other ForwardingDescriptors.
void WriteConfig(std::ostream &os, const std::vector< std::string > &node_names) const
SumDescriptor * GenRandSumDescriptor(int32 num_nodes)
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
This ForwardingDescriptor modifies the indexes (n, t, x) by replacing one of them (normally t) with a...
void UnitTestGeneralDescriptor()
This file contains class definitions for classes ForwardingDescriptor, SumDescriptor and Descriptor...
#define KALDI_LOG
Definition: kaldi-error.h:153