nnet-component-test.cc
Go to the documentation of this file.
1 // nnet3/nnet-component-test.cc
2 
3 // Copyright 2015 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include "nnet3/nnet-nnet.h"
22 #include "nnet3/nnet-test-utils.h"
23 
24 namespace kaldi {
25 namespace nnet3 {
26 // Reset seeds for test time for RandomComponent
27 static void ResetSeed(int32 rand_seed, const Component &c) {
28  RandomComponent *rand_component =
29  const_cast<RandomComponent*>(dynamic_cast<const RandomComponent*>(&c));
30 
31  if (rand_component != NULL) {
32  srand(rand_seed);
33  rand_component->ResetGenerator();
34  }
35 }
36 
37 // this is the same as calling StringsApproxEqual(), except it prints
38 // a warning if it fails.
39 bool CheckStringsApproxEqual(const std::string &a,
40  const std::string &b,
41  int32 tolerance = 3) {
42  if (!StringsApproxEqual(a, b, tolerance)) {
43  KALDI_WARN << "Strings differ: " << a
44  << "\nvs.\n" << b;
45  return false;
46  } else {
47  return true;
48  }
49 }
50 
51 
// Body of the component I/O round-trip test (presumably
// TestNnetComponentIo(Component *c); its header, original line 52, is absent
// from this listing). Writes 'c' in a randomly chosen mode (binary or text),
// reads it back with Component::ReadNew(), and writes the re-read copy again.
53  bool binary = (Rand() % 2 == 0);
54  std::ostringstream os1;
55  c->Write(os1, binary);
56  std::istringstream is(os1.str());
57  Component *c2 = Component::ReadNew(is, binary);
58  std::ostringstream os2;
59  c2->Write(os2, binary);
60  if (!binary) {
// Only the text (non-binary) serializations are examined.
// NOTE(review): original line 62 is missing from this listing; as shown,
// s1 and s2 are computed but never compared. Confirm against the real file.
61  std::string s1 = os1.str(), s2 = os2.str();
63  }
64  delete c2;
65 }
66 
// Checks that Copy() yields a component whose Info() string is approximately
// equal to the original's; dies via KALDI_ERR otherwise.
// (The function header, original line 67, is absent from this listing.)
68  Component *c2 = c->Copy();
69  if (!StringsApproxEqual(c->Info(), c2->Info())) {
70  KALDI_ERR << "Expected info strings to be equal: '"
71  << c->Info() << "' vs. '" << c2->Info() << "'";
72  }
73  delete c2;
74 }
75 
// Exercises Add() and Scale() on two copies of 'c'. c3 becomes
// c2 + 0.5 * c2, and c2 is scaled by 1.5. For an updatable component both
// should therefore presumably hold 1.5 times the original parameters.
// NOTE(review): the header (original line 76) and original line 81, which
// presumably compares c2 and c3, are absent from this listing.
77  Component *c2 = c->Copy();
78  Component *c3 = c2->Copy();
79  c3->Add(0.5, *c2);
80  c2->Scale(1.5);
82  delete c2;
83  delete c3;
84 }
85 
// Vectorize()/UnVectorize() round-trip test for updatable components;
// components without kUpdatableComponent return immediately. It checks three
// things:
//  - a zero-scaled copy vectorizes to all zeros;
//  - after unvectorizing the original's parameters into that copy, the dot
//    products x, y and z agree;
//  - re-vectorizing the copy reproduces the same vector exactly.
// (The header, original line 86, and original line 98 are absent from this
// listing.)
87  if (!(c->Properties() & kUpdatableComponent))
88  return;
89  UpdatableComponent *uc = dynamic_cast<UpdatableComponent*>(c);
90  KALDI_ASSERT(uc != NULL);
91  UpdatableComponent *uc2 = dynamic_cast<UpdatableComponent*>(uc->Copy());
92  uc2->Scale(0.0);
93  Vector<BaseFloat> params(uc2->NumParameters());
94  uc2->Vectorize(&params);
// Min() == 0 together with Sum() == 0 implies every entry is exactly zero.
95  KALDI_ASSERT(params.Min()==0.0 && params.Sum()==0.0);
96  uc->Vectorize(&params);
97  uc2->UnVectorize(params);
99  BaseFloat x = uc2->DotProduct(*uc2), y = uc->DotProduct(*uc),
100  z = uc2->DotProduct(*uc);
101  KALDI_ASSERT(ApproxEqual(x, y) && ApproxEqual(y, z));
102  Vector<BaseFloat> params2(uc2->NumParameters());
103  uc2->Vectorize(&params2);
// Exact (not approximate) element-wise equality is required here.
104  for(int i = 0; i < params.Dim(); i++)
105  KALDI_ASSERT(params(i) == params2(i));
106  delete uc2;
107 }
108 
// Checks Scale()/Add()/Vectorize()/UnVectorize() invariants for updatable
// components; returns immediately for components lacking kUpdatableComponent.
// (The header, original line 109, and the first line of the diagnostic
// statement, original line 114, are absent from this listing.)
110  if (!(c->Properties() & kUpdatableComponent))
111  return;
112  UpdatableComponent *uc = dynamic_cast<UpdatableComponent*>(c);
113  if (uc == NULL) {
115  "Component returns updatable flag but does not inherit "
116  "from UpdatableComponent");
117  return;
118  }
// NOTE(review): the early return above already handled components whose
// Properties() lacks kUpdatableComponent, and uc points to the same object,
// so this first branch looks unreachable (assuming Properties() returns the
// same value on both calls). Consider removing it or moving the check up.
119  if(!(uc->Properties() & kUpdatableComponent)){
120  // testing that if it declares itself as non-updatable,
121  // Scale() and Add() have no effect.
122  KALDI_ASSERT(uc->NumParameters() == 0);
123  KALDI_ASSERT(uc->DotProduct(*uc) == 0);
124  UpdatableComponent *uc2 = dynamic_cast<UpdatableComponent*>(uc->Copy());
125  uc2->Scale(7.0);
126  uc2->Add(3.0, *uc);
127  KALDI_ASSERT(CheckStringsApproxEqual(uc2->Info(), uc->Info()));
128  uc->Scale(0.0);
129  KALDI_ASSERT(CheckStringsApproxEqual(uc2->Info(), uc->Info()));
130  delete uc2;
131  } else {
132  KALDI_ASSERT(uc->NumParameters() != 0);
133  UpdatableComponent *uc2 = dynamic_cast<UpdatableComponent*>(uc->Copy()),
134  *uc3 = dynamic_cast<UpdatableComponent*>(uc->Copy());
135

136  // testing some expected invariances of scale and add.
137  uc2->Scale(5.0);
138  uc2->Add(3.0, *uc3);
139  uc3->Scale(8.0);
140  // now they should both be scaled to 8 times the original component.
141  if (!StringsApproxEqual(uc2->Info(), uc3->Info())) {
142  KALDI_ERR << "Expected info strings to be equal: '"
143  << uc2->Info() << "' vs. '" << uc3->Info() << "'";
144  }
145  // testing that scaling by 0.5 works the same whether
146  // done on the vectorized parameters or via Scale().
147  Vector<BaseFloat> vec2(uc->NumParameters());
148  uc2->Vectorize(&vec2);
149  vec2.Scale(0.5);
150  uc2->UnVectorize(vec2);
151  uc3->Scale(0.5);
152  KALDI_ASSERT(CheckStringsApproxEqual(uc2->Info(), uc3->Info()));
153

154  // testing that zeroing the vectorized parameters (SetZero() + UnVectorize())
155  // gives the same result as Scale(0.0), and that vectorizing a component
156  // that has been scaled by zero gives all-zero parameters.
157  uc2->Vectorize(&vec2);
158  vec2.SetZero();
159  uc2->UnVectorize(vec2);
160  uc3->Scale(0.0);
161  uc3->Vectorize(&vec2);
162  KALDI_ASSERT(uc2->Info() == uc3->Info() && VecVec(vec2, vec2) == 0.0);
163

164  delete uc2;
165  delete uc3;
166  }
167 }
168 
169 
170 /*
171  This function gets the 'ComponentPrecomputedIndexes*' pointer from
172  a component, given the num-rows in the matrix of inputs we're testing it
173  with. It uses a plausible arrangement of indexes.
174 
175  Note: in this file we primarily test simple components, and simple
176  components don't return precomputed indexes; but we also test a
177  few non-simple components that operate with the same set of indexes
178  on the input and the output. Simple components would return NULL.
179  */
181  int32 num_rows) {
// Builds num_rows input indexes with x = 0. The n value cycles through
// num_t_values values, and t advances once per completed cycle. The divisor
// is 3 or 2 when it divides num_rows, otherwise 1.
// NOTE(review): despite its name, num_t_values is the number of distinct
// 'n' values (n = i % num_t_values); t takes num_rows / num_t_values values.
// (The header, original line 180, and original line 200, which presumably
// starts the c.PrecomputeIndexes(misc_info, ...) call that defines 'ans',
// are absent from this listing.)
182  std::vector<Index> input_indexes(num_rows);
183  int32 num_t_values;
184  if (num_rows % 3 == 0) { num_t_values = 3; }
185  else if (num_rows % 2 == 0) { num_t_values = 2; }
186  else { num_t_values = 1; }
187

188  for (int32 i = 0; i < num_rows; i++) {
189  input_indexes[i].n = i % num_t_values;
190  input_indexes[i].x = 0;
191  input_indexes[i].t = i / num_t_values;
192  }
// Output indexes start identical to the inputs; components that declare
// kReordersIndexes may then reorder both lists.
193  std::vector<Index> output_indexes(input_indexes);
194

195  if (c.Properties()&kReordersIndexes) {
196  c.ReorderIndexes(&input_indexes, &output_indexes);
197  }
198  MiscComputationInfo misc_info;
199  bool need_backprop = true; // just in case.
201  input_indexes,
202  output_indexes,
203  need_backprop);
204  // ans will be NULL in most cases.
205  return ans;
206 }
207 
208 // tests the properties kPropagateAdds, kBackpropAdds, kPropagateInPlace,
209 // kBackpropInPlace, kBackpropNeedsInput and kBackpropNeedsOutput.
// Body of the propagate/backprop property test. Its header, original line
// 210, is absent from this listing, as are original lines 218, 220,
// 222-223 and 234.
// RandomComponents are re-seeded before every Propagate() so that the runs
// are comparable. The test then does the following:
//  - propagates the same random input into a zeroed matrix and into one
//    pre-filled with ones, and backprops likewise, to check kPropagateAdds
//    and kBackpropAdds;
//  - checks in-place propagation when kPropagateInPlace is set;
//  - checks in-place backprop when kBackpropInPlace is set.
211  int32 properties = c.Properties();
212  Component *c_copy = NULL;
213  int32 rand_seed = Rand();
214

// When non-NULL, c_copy is passed to Backprop() as 'to_update'.
215  if (RandInt(0, 1) == 0)
216  c_copy = c.Copy(); // This will test backprop with an updatable component.
217  MatrixStrideType input_stride_type = (c.Properties()&kInputContiguous) ?
219  MatrixStrideType output_stride_type = (c.Properties()&kOutputContiguous) ?
221  MatrixStrideType both_stride_type =
224

225  int32 input_dim = c.InputDim(),
226  output_dim = c.OutputDim(),
227  num_rows = RandInt(1, 100);
228  CuMatrix<BaseFloat> input_data(num_rows, input_dim, kUndefined,
229  input_stride_type);
230  input_data.SetRandn();
// output_data3 starts as a copy of the input and is used for the in-place
// propagate test. NOTE(review): it has input_dim columns, which assumes
// input_dim == output_dim whenever kPropagateInPlace is set.
231  CuMatrix<BaseFloat> output_data3(num_rows, input_dim, kSetZero,
232  output_stride_type);
233  output_data3.CopyFromMat(input_data);
235  output_data1(num_rows, output_dim, kSetZero, output_stride_type),
236  output_data2(num_rows, output_dim, kSetZero, output_stride_type);
// Offset of 1.0, removed below when kPropagateAdds is set.
237  output_data2.Add(1.0);
238

239  if ((properties & kPropagateAdds) && (properties & kPropagateInPlace)) {
240  KALDI_ERR << "kPropagateAdds and kPropagateInPlace flags are incompatible.";
241  }
242

243  ResetSeed(rand_seed, c);
244  ComponentPrecomputedIndexes *indexes = GetPrecomputedIndexes(c, num_rows);
// This memo is kept for all Backprop() calls below and freed after them.
245  void *memo = c.Propagate(indexes, input_data, &output_data1);
246

247  ResetSeed(rand_seed, c);
248  c.DeleteMemo(c.Propagate(indexes, input_data, &output_data2));
249  if (properties & kPropagateInPlace) {
250  ResetSeed(rand_seed, c);
251  c.DeleteMemo(c.Propagate(indexes, output_data3, &output_data3));
252  if (!output_data1.ApproxEqual(output_data3)) {
253  KALDI_ERR << "Test of kPropagateInPlace flag for component of type "
254  << c.Type() << " failed.";
255  }
256  }
257  if (properties & kPropagateAdds)
258  output_data2.Add(-1.0); // remove the offset
259  AssertEqual(output_data1, output_data2);
260

261

262  CuMatrix<BaseFloat> output_deriv(num_rows, output_dim, kSetZero, output_stride_type);
263  output_deriv.SetRandn();
264  CuMatrix<BaseFloat> input_deriv1(num_rows, input_dim, kSetZero, input_stride_type),
265  input_deriv2(num_rows, input_dim, kSetZero, input_stride_type);
// input_deriv3 starts as a copy of output_deriv and serves as both out_deriv
// and in_deriv in the in-place backprop test.
266  CuMatrix<BaseFloat> input_deriv3(num_rows, output_dim, kSetZero, both_stride_type);
267  input_deriv3.CopyFromMat(output_deriv);
268

269  input_deriv2.Add(1.0);
270  CuMatrix<BaseFloat> empty_mat;
271

272  // test with input_deriv1 that's zero
273  c.Backprop("foobar", indexes,
274  ((properties & kBackpropNeedsInput) ? input_data : empty_mat),
275  ((properties & kBackpropNeedsOutput) ? output_data1 : empty_mat),
276  output_deriv,
277  memo,
278  c_copy,
279  &input_deriv1);
280  // test with input_deriv2 that's all ones.
281  c.Backprop("foobar", indexes,
282  ((properties & kBackpropNeedsInput) ? input_data : empty_mat),
283  ((properties & kBackpropNeedsOutput) ? output_data1 : empty_mat),
284  output_deriv,
285  memo,
286  c_copy,
287  &input_deriv2);
288  // test backprop in place, if supported.
289  if (properties & kBackpropInPlace) {
290  c.Backprop("foobar", indexes,
291  ((properties & kBackpropNeedsInput) ? input_data : empty_mat),
292  ((properties & kBackpropNeedsOutput) ? output_data1 : empty_mat),
293  input_deriv3,
294  memo,
295  c_copy,
296  &input_deriv3);
297  }
298  c.DeleteMemo(memo);
299

300  if (properties & kBackpropAdds)
301  input_deriv2.Add(-1.0); // subtract the offset.
302  AssertEqual(input_deriv1, input_deriv2);
303  if (properties & kBackpropInPlace)
304  AssertEqual(input_deriv1, input_deriv3);
305  delete c_copy;
306  delete indexes;
307 }
308 
310  BaseFloat perturb_delta) {
// Finite-difference check of the data derivative. Its header, original line
// 309, is absent from this listing, as are the completions of the two stride
// ternaries (original lines 312 and 314).
// For test_dim random input perturbations of scale perturb_delta, the check
// compares two quantities:
//  - predicted change: TraceMatMat(perturbation, input_deriv, kTrans);
//  - measured change: TraceMatMat(output_deriv, perturbed_output, kTrans)
//    minus the unperturbed value.
// Returns whether the two agree within 'threshold'. It always returns true
// for the two exempted component types near the end.
311  MatrixStrideType input_stride_type = (c.Properties()&kInputContiguous) ?
313  MatrixStrideType output_stride_type = (c.Properties()&kOutputContiguous) ?
315

316  int32 input_dim = c.InputDim(),
317  output_dim = c.OutputDim(),
318  num_rows = RandInt(1, 100),
319  rand_seed = Rand();
320  int32 properties = c.Properties();
321  CuMatrix<BaseFloat> input_data(num_rows, input_dim, kSetZero, input_stride_type),
322  output_data(num_rows, output_dim, kSetZero, output_stride_type),
323  output_deriv(num_rows, output_dim, kSetZero, output_stride_type);
324  input_data.SetRandn();
325  output_deriv.SetRandn();
326

327  ResetSeed(rand_seed, c);
328  ComponentPrecomputedIndexes *indexes = GetPrecomputedIndexes(c, num_rows);
329  void *memo = c.Propagate(indexes, input_data, &output_data);
330

331  CuMatrix<BaseFloat> input_deriv(num_rows, input_dim, kSetZero, input_stride_type),
332  empty_mat;
333  c.Backprop("foobar", indexes,
334  ((properties & kBackpropNeedsInput) ? input_data : empty_mat),
335  ((properties & kBackpropNeedsOutput) ? output_data : empty_mat),
336  output_deriv, memo, NULL, &input_deriv);
337  c.DeleteMemo(memo);
338

339  int32 test_dim = 3;
340  BaseFloat original_objf = TraceMatMat(output_deriv, output_data, kTrans);
341  Vector<BaseFloat> measured_objf_change(test_dim),
342  predicted_objf_change(test_dim);
343  for (int32 i = 0; i < test_dim; i++) {
344  CuMatrix<BaseFloat> perturbed_input_data(num_rows, input_dim,
345  kSetZero, input_stride_type),
346  perturbed_output_data(num_rows, output_dim,
347  kSetZero, output_stride_type);
348  perturbed_input_data.SetRandn();
349  perturbed_input_data.Scale(perturb_delta);
350  // at this point, perturbed_input_data contains only the offset that will be added to the input data.
351  predicted_objf_change(i) = TraceMatMat(perturbed_input_data, input_deriv,
352  kTrans);
353  perturbed_input_data.AddMat(1.0, input_data);
354

// Re-seed so that a RandomComponent draws the same noise as in the first
// Propagate() above.
355  ResetSeed(rand_seed, c);
356  c.DeleteMemo(c.Propagate(indexes, perturbed_input_data, &perturbed_output_data));
357  measured_objf_change(i) = TraceMatMat(output_deriv, perturbed_output_data,
358  kTrans) - original_objf;
359  }
360  KALDI_LOG << "Predicted objf-change = " << predicted_objf_change;
361  KALDI_LOG << "Measured objf-change = " << measured_objf_change;
362  BaseFloat threshold = 0.1;
363  bool ans = ApproxEqual(predicted_objf_change, measured_objf_change, threshold);
364  if (!ans)
365  KALDI_WARN << "Data-derivative test failed, component-type="
366  << c.Type() << ", input-dim=" << input_dim
367  << ", output-dim=" << output_dim;
// NOTE(review): both early 'return true' paths below skip the final
// 'delete indexes', leaking the precomputed indexes (when non-NULL) for these
// component types. Deleting 'indexes' before these checks would fix it.
368  if (c.Type() == "NormalizeComponent" && input_dim == 1) {
369  // derivatives are mathematically zero, but the measured and predicted
370  // objf have different roundoff and the relative differences are large.
371  // this is not unexpected.
372  KALDI_LOG << "Accepting deriv differences since it is NormalizeComponent "
373  << "with dim=1.";
374  return true;
375  }
376  else if (c.Type() == "ClipGradientComponent") {
377  KALDI_LOG << "Accepting deriv differences since "
378  << "it is ClipGradientComponent.";
379  return true;
380  }
381  delete indexes;
382  return ans;
383 }
384 
385 
386 // if test_derivative == false then the test only checks that the update
387 // improves the objective (new objf > original objf). if true, then we measure the actual model-derivative
388 // and check that it's accurate.
389 // returns true on success, false on test failure.
391  BaseFloat perturb_delta,
392  bool test_derivative) {
// Model-derivative test for updatable components. It returns true at once
// for components without kUpdatableComponent. Its header (original line 390)
// and original lines 402 and 404 are absent from this listing.
// Backprop() writes its update into a copy, c_copy, of 'c'. The two modes:
//  - test_derivative == false: checks that propagating through the updated
//    copy increases the objective. perturb_delta is not used in this mode.
//  - test_derivative == true: c_copy is first zeroed and marked as a
//    gradient. Then, for test_dim random parameter perturbations, the
//    predicted change (dot product with the gradient) is compared with the
//    measured change in the objective.
393  int32 input_dim = c.InputDim(),
394  output_dim = c.OutputDim(),
395  num_rows = RandInt(1, 100);
396  int32 properties = c.Properties();
397  if ((properties & kUpdatableComponent) == 0) {
398  // nothing to test.
399  return true;
400  }
401  MatrixStrideType input_stride_type = (c.Properties()&kInputContiguous) ?
403  MatrixStrideType output_stride_type = (c.Properties()&kOutputContiguous) ?
405

406  CuMatrix<BaseFloat> input_data(num_rows, input_dim, kSetZero, input_stride_type),
407  output_data(num_rows, output_dim, kSetZero, output_stride_type),
408  output_deriv(num_rows, output_dim, kSetZero, output_stride_type);
409  input_data.SetRandn();
410  output_deriv.SetRandn();
411

412  ComponentPrecomputedIndexes *indexes = GetPrecomputedIndexes(c, num_rows);
413  void *memo = c.Propagate(indexes, input_data, &output_data);
414

415  BaseFloat original_objf = TraceMatMat(output_deriv, output_data, kTrans);
416

417  Component *c_copy = c.Copy();
418

419  const UpdatableComponent *uc = dynamic_cast<const UpdatableComponent*>(&c);
420  UpdatableComponent *uc_copy = dynamic_cast<UpdatableComponent*>(c_copy);
421  KALDI_ASSERT(uc != NULL && uc_copy != NULL);
422  if (test_derivative) {
423  uc_copy->Scale(0.0);
424  uc_copy->SetAsGradient();
425  }
426

427  CuMatrix<BaseFloat> input_deriv(num_rows, input_dim,
428  kSetZero, input_stride_type),
429  empty_mat;
// The input derivative is requested only at random, so both the
// in_deriv != NULL and in_deriv == NULL paths of Backprop() get exercised.
430  c.Backprop("foobar", indexes,
431  ((properties & kBackpropNeedsInput) ? input_data : empty_mat),
432  ((properties & kBackpropNeedsOutput) ? output_data : empty_mat),
433  output_deriv, memo, c_copy,
434  (RandInt(0, 1) == 0 ? &input_deriv : NULL));
435  c.DeleteMemo(memo);
436

437  if (!test_derivative) { // Just testing that the model update improves the objective.
438  CuMatrix<BaseFloat> new_output_data(num_rows, output_dim,
439  kSetZero, output_stride_type);
440  c.DeleteMemo(c_copy->Propagate(indexes, input_data, &new_output_data));
441

442  BaseFloat new_objf = TraceMatMat(output_deriv, new_output_data, kTrans);
443  bool ans = (new_objf > original_objf);
444  if (!ans) {
445  KALDI_WARN << "After update, new objf is not better than the original objf: "
446  << new_objf << " <= " << original_objf;
447  }
448  delete c_copy;
449  delete indexes;
450  return ans;
451  } else {
452  // check that the model derivative is accurate.
453  int32 test_dim = 3;
454

455  Vector<BaseFloat> measured_objf_change(test_dim),
456  predicted_objf_change(test_dim);
457  for (int32 i = 0; i < test_dim; i++) {
458  CuMatrix<BaseFloat> perturbed_output_data(num_rows, output_dim,
459  kSetZero, output_stride_type);
460  Component *c_perturbed = c.Copy();
461  UpdatableComponent *uc_perturbed =
462  dynamic_cast<UpdatableComponent*>(c_perturbed);
463  KALDI_ASSERT(uc_perturbed != NULL);
464  uc_perturbed->PerturbParams(perturb_delta);
465

// uc_copy holds the gradient. Its dot-product difference between the
// perturbed and original parameters is the first-order predicted change.
466  predicted_objf_change(i) = uc_copy->DotProduct(*uc_perturbed) -
467  uc_copy->DotProduct(*uc);
// NOTE(review): the memo returned here is never passed to DeleteMemo(),
// unlike the other Propagate() calls in this file. It leaks for components
// that return a non-NULL memo; wrap the call in c.DeleteMemo(...).
468  c_perturbed->Propagate(indexes, input_data, &perturbed_output_data);
469  measured_objf_change(i) = TraceMatMat(output_deriv, perturbed_output_data,
470  kTrans) - original_objf;
471  delete c_perturbed;
472  }
473  KALDI_LOG << "Predicted objf-change = " << predicted_objf_change;
474  KALDI_LOG << "Measured objf-change = " << measured_objf_change;
475  BaseFloat threshold = 0.1;
476

477  bool ans = ApproxEqual(predicted_objf_change, measured_objf_change,
478  threshold);
479  if (!ans)
480  KALDI_WARN << "Model-derivative test failed, component-type="
481  << c.Type() << ", input-dim=" << input_dim
482  << ", output-dim=" << output_dim;
483  delete c_copy;
484  delete indexes;
485  return ans;
486  }
487 }
488 
489 
// Main test loop over 200 components. Several lines are absent from this
// listing: the header (original line 490) and original lines 492 and
// 494-499. Those presumably create 'c' (e.g. via
// GenerateRandomSimpleComponent()) and call the other tests.
// Each derivative test is retried with several perturbation sizes. Because of
// the && chains, evaluation stops at the first passing attempt, and
// KALDI_ERR is raised only if every attempt fails.
491  for (int32 n = 0; n < 200; n++) {
493  KALDI_LOG << c->Info();
500  if (!TestSimpleComponentDataDerivative(*c, 1.0e-04) &&
501  !TestSimpleComponentDataDerivative(*c, 1.0e-03) &&
502  !TestSimpleComponentDataDerivative(*c, 1.0e-05) &&
503  !TestSimpleComponentDataDerivative(*c, 1.0e-06))
504  KALDI_ERR << "Component data-derivative test failed";
505

// NOTE(review): TestSimpleComponentModelDerivative reads perturb_delta only
// when test_derivative is true. These three attempts therefore differ only in
// their random data, not in perturbation size.
506  if (!TestSimpleComponentModelDerivative(*c, 1.0e-04, false) &&
507  !TestSimpleComponentModelDerivative(*c, 1.0e-03, false) &&
508  !TestSimpleComponentModelDerivative(*c, 1.0e-06, false))
509  KALDI_ERR << "Component downhill-update test failed";
510

511  if (!TestSimpleComponentModelDerivative(*c, 1.0e-04, true) &&
512  !TestSimpleComponentModelDerivative(*c, 1.0e-03, true) &&
513  !TestSimpleComponentModelDerivative(*c, 1.0e-05, true) &&
514  !TestSimpleComponentModelDerivative(*c, 1.0e-06, true))
515  KALDI_ERR << "Component model-derivative test failed";
516

517  delete c;
518  }
519 }
520 
521 } // namespace nnet3
522 } // namespace kaldi
523 
// Test driver. When built with HAVE_CUDA == 1, the body runs twice: first
// after SelectGpuId("no"), then after SelectGpuId("yes"). The CUDA profile is
// printed afterwards.
// NOTE(review): original line 536, presumably the call that actually runs the
// tests, is absent from this listing. As shown, the loop body only selects
// the device.
524 int main() {
525  using namespace kaldi;
526  using namespace kaldi::nnet3;
527 #if HAVE_CUDA == 1
528  kaldi::int32 loop = 0;
529  for (loop = 0; loop < 2; loop++) {
530  //CuDevice::Instantiate().SetDebugStrideMode(true);
531  if (loop == 0)
532  CuDevice::Instantiate().SelectGpuId("no");
533  else
534  CuDevice::Instantiate().SelectGpuId("yes");
535 #endif
537 #if HAVE_CUDA == 1
538  } // No for loop if 'HAVE_CUDA != 1',
539  CuDevice::Instantiate().PrintProfile();
540 #endif
541  KALDI_LOG << "Nnet component tests succeeded.";
542

543  return 0;
544 }
void CopyFromMat(const MatrixBase< OtherReal > &src, MatrixTransposeType trans=kNoTrans)
Definition: cu-matrix.cc:344
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
virtual void Write(std::ostream &os, bool binary) const =0
Write component to stream.
virtual void * Propagate(const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) const =0
Propagate function.
void TestSimpleComponentPropagateProperties(const Component &c)
Abstract base-class for neural-net components.
int main()
virtual Component * Copy() const =0
Copies component (deep copy).
virtual int32 NumParameters() const
The following new virtual function returns the total dimension of the parameters in this class...
This file contains various routines that are useful in test code.
kaldi::int32 int32
virtual int32 OutputDim() const =0
Returns output-dimension of this component.
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
Definition: matrix-common.h:71
virtual ComponentPrecomputedIndexes * PrecomputeIndexes(const MiscComputationInfo &misc_info, const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, bool need_backprop) const
This function must return NULL for simple Components.
virtual void DeleteMemo(void *memo) const
This virtual function only needs to be overwritten by Components that return a non-NULL memo from the...
void TestNnetComponentAddScale(Component *c)
virtual void Vectorize(VectorBase< BaseFloat > *params) const
Turns the parameters into vector form.
virtual void Backprop(const std::string &debug_info, const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, const CuMatrixBase< BaseFloat > &out_deriv, void *memo, Component *to_update, CuMatrixBase< BaseFloat > *in_deriv) const =0
Backprop function; depending on which of the arguments 'to_update' and 'in_deriv' are non-NULL...
This file contains declarations of components that are "simple", meaning they don't care about the in...
static void ResetSeed(int32 rand_seed, const Component &c)
bool TestSimpleComponentModelDerivative(const Component &c, BaseFloat perturb_delta, bool test_derivative)
bool StringsApproxEqual(const std::string &a, const std::string &b, int32 decimal_places_tolerance)
This function returns true when two text strings are approximately equal, and false when they are not...
Definition: text-utils.cc:335
void TestNnetComponentVectorizeUnVectorize(Component *c)
virtual void Scale(BaseFloat scale)
This virtual function when called on – an UpdatableComponent scales the parameters by "scale" when c...
virtual BaseFloat DotProduct(const UpdatableComponent &other) const =0
Computes dot-product between parameters of two instances of a Component.
void Add(Real value)
Definition: cu-matrix.cc:582
virtual int32 Properties() const =0
Return bitmask of the component's properties.
virtual void ReorderIndexes(std::vector< Index > *input_indexes, std::vector< Index > *output_indexes) const
This function only does something interesting for non-simple Components.
struct rnnlm::@11::@12 n
MatrixStrideType
Definition: matrix-common.h:44
bool CheckStringsApproxEqual(const std::string &a, const std::string &b, int32 tolerance=3)
#define KALDI_ERR
Definition: kaldi-error.h:147
static Component * ReadNew(std::istream &is, bool binary)
Read component from stream (works out its type). Dies on error.
#define KALDI_WARN
Definition: kaldi-error.h:150
Real TraceMatMat(const MatrixBase< Real > &A, const MatrixBase< Real > &B, MatrixTransposeType trans)
We need to declare this here as it will be a friend function.
bool TestSimpleComponentDataDerivative(const Component &c, BaseFloat perturb_delta)
void Scale(Real alpha)
Multiplies all elements by this constant.
void TestNnetComponentCopy(Component *c)
int Rand(struct RandomState *state)
Definition: kaldi-math.cc:45
virtual void SetAsGradient()
Sets is_gradient_ to true and sets learning_rate_ to 1, ignoring learning_rate_factor_.
ComponentPrecomputedIndexes * GetPrecomputedIndexes(const Component &c, int32 num_rows)
Class UpdatableComponent is a Component which has trainable parameters; it extends the interface of C...
virtual std::string Type() const =0
Returns a string such as "SigmoidComponent", describing the type of the object.
virtual void PerturbParams(BaseFloat stddev)=0
This function is to be used in testing.
virtual void UnVectorize(const VectorBase< BaseFloat > &params)
Converts the parameters from vector form.
A class representing a vector.
Definition: kaldi-vector.h:406
virtual std::string Info() const
Returns some text-form information about this component, for diagnostics.
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
static void AssertEqual(float a, float b, float relative_tolerance=0.001)
assert abs(a - b) <= relative_tolerance * (abs(a)+abs(b))
Definition: kaldi-math.h:276
virtual std::string Info() const
Returns some text-form information about this component, for diagnostics.
void TestNnetComponentUpdatable(Component *c)
virtual int32 InputDim() const =0
Returns input-dimension of this component.
#define KALDI_LOG
Definition: kaldi-error.h:153
Real VecVec(const VectorBase< Real > &a, const VectorBase< Real > &b)
Returns dot product between v1 and v2.
Definition: kaldi-vector.cc:37
virtual void Add(BaseFloat alpha, const Component &other)
This virtual function when called by – an UpdatableComponent adds the parameters of another updatabl...
static bool ApproxEqual(float a, float b, float relative_tolerance=0.001)
return abs(a - b) <= relative_tolerance * (abs(a)+abs(b)).
Definition: kaldi-math.h:265
void TestNnetComponentIo(Component *c)
int32 RandInt(int32 min_val, int32 max_val, struct RandomState *state)
Definition: kaldi-math.cc:95
Component * GenerateRandomSimpleComponent()
Generates random simple component for testing.