31 if (rand_component != NULL) {
41 int32 tolerance = 3) {
53 bool binary = (
Rand() % 2 == 0);
54 std::ostringstream os1;
55 c->
Write(os1, binary);
56 std::istringstream is(os1.str());
58 std::ostringstream os2;
59 c2->Write(os2, binary);
61 std::string s1 = os1.str(), s2 = os2.str();
70 KALDI_ERR <<
"Expected info strings to be equal: '" 71 << c->
Info() <<
"' vs. '" << c2->
Info() <<
"'";
104 for(
int i = 0;
i < params.Dim();
i++)
115 "Component returns updatable flag but does not inherit " 116 "from UpdatableComponent");
134 *uc3 = dynamic_cast<UpdatableComponent*>(uc->
Copy());
142 KALDI_ERR <<
"Expected info strings to be equal: '" 143 << uc2->Info() <<
"' vs. '" << uc3->Info() <<
"'";
148 uc2->Vectorize(&vec2);
150 uc2->UnVectorize(vec2);
157 uc2->Vectorize(&vec2);
159 uc2->UnVectorize(vec2);
161 uc3->Vectorize(&vec2);
182 std::vector<Index> input_indexes(num_rows);
184 if (num_rows % 3 == 0) { num_t_values = 3; }
185 else if (num_rows % 2 == 0) { num_t_values = 2; }
186 else { num_t_values = 1; }
188 for (
int32 i = 0;
i < num_rows;
i++) {
189 input_indexes[
i].n =
i % num_t_values;
190 input_indexes[
i].x = 0;
191 input_indexes[
i].t =
i / num_t_values;
193 std::vector<Index> output_indexes(input_indexes);
199 bool need_backprop =
true;
235 output_data1(num_rows, output_dim,
kSetZero, output_stride_type),
236 output_data2(num_rows, output_dim,
kSetZero, output_stride_type);
237 output_data2.
Add(1.0);
240 KALDI_ERR <<
"kPropagateAdds and kPropagateInPlace flags are incompatible.";
245 void *memo = c.
Propagate(indexes, input_data, &output_data1);
249 if (properties & kPropagateInPlace) {
252 if (!output_data1.ApproxEqual(output_data3)) {
253 KALDI_ERR <<
"Test of kPropagateInPlace flag for component of type " 254 << c.
Type() <<
" failed.";
257 if (properties & kPropagateAdds)
258 output_data2.
Add(-1.0);
265 input_deriv2(num_rows, input_dim,
kSetZero, input_stride_type);
269 input_deriv2.
Add(1.0);
282 ((properties & kBackpropNeedsInput) ? input_data : empty_mat),
283 ((properties & kBackpropNeedsOutput) ? output_data1 : empty_mat),
291 ((properties & kBackpropNeedsInput) ? input_data : empty_mat),
292 ((properties & kBackpropNeedsOutput) ? output_data1 : empty_mat),
301 input_deriv2.
Add(-1.0);
303 if (properties & kBackpropInPlace)
322 output_data(num_rows, output_dim,
kSetZero, output_stride_type),
323 output_deriv(num_rows, output_dim,
kSetZero, output_stride_type);
324 input_data.SetRandn();
329 void *memo = c.
Propagate(indexes, input_data, &output_data);
336 output_deriv, memo, NULL, &input_deriv);
342 predicted_objf_change(test_dim);
343 for (
int32 i = 0;
i < test_dim;
i++) {
346 perturbed_output_data(num_rows, output_dim,
348 perturbed_input_data.SetRandn();
349 perturbed_input_data.Scale(perturb_delta);
351 predicted_objf_change(
i) =
TraceMatMat(perturbed_input_data, input_deriv,
353 perturbed_input_data.AddMat(1.0, input_data);
357 measured_objf_change(
i) =
TraceMatMat(output_deriv, perturbed_output_data,
360 KALDI_LOG <<
"Predicted objf-change = " << predicted_objf_change;
361 KALDI_LOG <<
"Measured objf-change = " << measured_objf_change;
363 bool ans =
ApproxEqual(predicted_objf_change, measured_objf_change, threshold);
365 KALDI_WARN <<
"Data-derivative test failed, component-type=" 366 << c.
Type() <<
", input-dim=" << input_dim
367 <<
", output-dim=" << output_dim;
368 if (c.
Type() ==
"NormalizeComponent" && input_dim == 1) {
372 KALDI_LOG <<
"Accepting deriv differences since it is NormalizeComponent " 376 else if (c.
Type() ==
"ClipGradientComponent") {
377 KALDI_LOG <<
"Accepting deriv differences since " 378 <<
"it is ClipGradientComponent.";
392 bool test_derivative) {
407 output_data(num_rows, output_dim,
kSetZero, output_stride_type),
408 output_deriv(num_rows, output_dim,
kSetZero, output_stride_type);
409 input_data.SetRandn();
413 void *memo = c.
Propagate(indexes, input_data, &output_data);
422 if (test_derivative) {
433 output_deriv, memo, c_copy,
434 (
RandInt(0, 1) == 0 ? &input_deriv : NULL));
437 if (!test_derivative) {
443 bool ans = (new_objf > original_objf);
445 KALDI_WARN <<
"After update, new objf is not better than the original objf: " 446 << new_objf <<
" <= " << original_objf;
456 predicted_objf_change(test_dim);
457 for (
int32 i = 0;
i < test_dim;
i++) {
466 predicted_objf_change(
i) = uc_copy->
DotProduct(*uc_perturbed) -
468 c_perturbed->
Propagate(indexes, input_data, &perturbed_output_data);
469 measured_objf_change(
i) =
TraceMatMat(output_deriv, perturbed_output_data,
473 KALDI_LOG <<
"Predicted objf-change = " << predicted_objf_change;
474 KALDI_LOG <<
"Measured objf-change = " << measured_objf_change;
477 bool ans =
ApproxEqual(predicted_objf_change, measured_objf_change,
480 KALDI_WARN <<
"Model-derivative test failed, component-type=" 481 << c.
Type() <<
", input-dim=" << input_dim
482 <<
", output-dim=" << output_dim;
504 KALDI_ERR <<
"Component data-derivative test failed";
509 KALDI_ERR <<
"Component downhill-update test failed";
515 KALDI_ERR <<
"Component model-derivative test failed";
525 using namespace kaldi;
529 for (loop = 0; loop < 2; loop++) {
532 CuDevice::Instantiate().SelectGpuId(
"no");
534 CuDevice::Instantiate().SelectGpuId(
"yes");
539 CuDevice::Instantiate().PrintProfile();
541 KALDI_LOG <<
"Nnet component tests succeeded.";
void CopyFromMat(const MatrixBase< OtherReal > &src, MatrixTransposeType trans=kNoTrans)
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
virtual void Write(std::ostream &os, bool binary) const =0
Write component to stream.
virtual void * Propagate(const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) const =0
Propagate function.
void TestSimpleComponentPropagateProperties(const Component &c)
Abstract base-class for neural-net components.
virtual Component * Copy() const =0
Copies component (deep copy).
virtual int32 NumParameters() const
The following new virtual function returns the total dimension of the parameters in this class...
This file contains various routines that are useful in test code.
virtual int32 OutputDim() const =0
Returns output-dimension of this component.
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
virtual ComponentPrecomputedIndexes * PrecomputeIndexes(const MiscComputationInfo &misc_info, const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, bool need_backprop) const
This function must return NULL for simple Components.
virtual void DeleteMemo(void *memo) const
This virtual function only needs to be overwritten by Components that return a non-NULL memo from the...
void TestNnetComponentAddScale(Component *c)
virtual void Vectorize(VectorBase< BaseFloat > *params) const
Turns the parameters into vector form.
virtual void Backprop(const std::string &debug_info, const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, const CuMatrixBase< BaseFloat > &out_deriv, void *memo, Component *to_update, CuMatrixBase< BaseFloat > *in_deriv) const =0
Backprop function; depending on which of the arguments 'to_update' and 'in_deriv' are non-NULL...
This file contains declarations of components that are "simple", meaning they don't care about the in...
static void ResetSeed(int32 rand_seed, const Component &c)
bool TestSimpleComponentModelDerivative(const Component &c, BaseFloat perturb_delta, bool test_derivative)
bool StringsApproxEqual(const std::string &a, const std::string &b, int32 decimal_places_tolerance)
This function returns true when two text strings are approximately equal, and false when they are not...
void TestNnetComponentVectorizeUnVectorize(Component *c)
virtual void Scale(BaseFloat scale)
This virtual function, when called on an UpdatableComponent, scales the parameters by "scale" when c...
virtual BaseFloat DotProduct(const UpdatableComponent &other) const =0
Computes dot-product between parameters of two instances of a Component.
virtual int32 Properties() const =0
Return bitmask of the component's properties.
virtual void ReorderIndexes(std::vector< Index > *input_indexes, std::vector< Index > *output_indexes) const
This function only does something interesting for non-simple Components.
void UnitTestNnetComponent()
bool CheckStringsApproxEqual(const std::string &a, const std::string &b, int32 tolerance=3)
static Component * ReadNew(std::istream &is, bool binary)
Read component from stream (works out its type). Dies on error.
Real TraceMatMat(const MatrixBase< Real > &A, const MatrixBase< Real > &B, MatrixTransposeType trans)
We need to declare this here as it will be a friend function.
bool TestSimpleComponentDataDerivative(const Component &c, BaseFloat perturb_delta)
void Scale(Real alpha)
Multiplies all elements by this constant.
void TestNnetComponentCopy(Component *c)
int Rand(struct RandomState *state)
virtual void SetAsGradient()
Sets is_gradient_ to true and sets learning_rate_ to 1, ignoring learning_rate_factor_.
ComponentPrecomputedIndexes * GetPrecomputedIndexes(const Component &c, int32 num_rows)
Class UpdatableComponent is a Component which has trainable parameters; it extends the interface of C...
virtual std::string Type() const =0
Returns a string such as "SigmoidComponent", describing the type of the object.
virtual void PerturbParams(BaseFloat stddev)=0
This function is to be used in testing.
virtual void UnVectorize(const VectorBase< BaseFloat > &params)
Converts the parameters from vector form.
A class representing a vector.
virtual std::string Info() const
Returns some text-form information about this component, for diagnostics.
#define KALDI_ASSERT(cond)
static void AssertEqual(float a, float b, float relative_tolerance=0.001)
assert abs(a - b) <= relative_tolerance * (abs(a)+abs(b))
virtual std::string Info() const
Returns some text-form information about this component, for diagnostics.
void TestNnetComponentUpdatable(Component *c)
virtual int32 InputDim() const =0
Returns input-dimension of this component.
Real VecVec(const VectorBase< Real > &a, const VectorBase< Real > &b)
Returns dot product between a and b.
virtual void Add(BaseFloat alpha, const Component &other)
This virtual function, when called by an UpdatableComponent, adds the parameters of another updatabl...
static bool ApproxEqual(float a, float b, float relative_tolerance=0.001)
return abs(a - b) <= relative_tolerance * (abs(a)+abs(b)).
void TestNnetComponentIo(Component *c)
int32 RandInt(int32 min_val, int32 max_val, struct RandomState *state)
Component * GenerateRandomSimpleComponent()
Generates random simple component for testing.