29 using namespace kaldi;
32 typedef kaldi::int64 int64;
35 "Given an old and a new model and some training examples (possibly held-out),\n" 36 "show the average objective function given the mean of the two models,\n" 37 "and the breakdown by component of why this happened (computed from\n" 38 "derivative information). Also shows parameter differences per layer.\n" 39 "If training examples not provided, only shows parameter differences per\n" 42 "Usage: nnet-show-progress [options] <old-model-in> <new-model-in> [<training-examples-in>]\n" 43 "e.g.: nnet-show-progress 1.nnet 2.nnet ark:valid.egs\n";
47 int32 num_segments = 1;
48 int32 batch_size = 1024;
49 std::string use_gpu =
"optional";
51 po.Register(
"num-segments", &num_segments,
52 "Number of line segments used for computing derivatives");
53 po.Register(
"use-gpu", &use_gpu,
54 "yes|no|optional|wait, only has effect if compiled with CUDA");
58 if (po.NumArgs() < 2 || po.NumArgs() > 3) {
64 CuDevice::Instantiate().SelectGpuId(use_gpu);
67 std::string nnet1_rxfilename = po.GetArg(1),
68 nnet2_rxfilename = po.GetArg(2),
69 examples_rspecifier = po.GetOptArg(3);
75 Input ki(nnet1_rxfilename, &binary_read);
76 trans_model.
Read(ki.Stream(), binary_read);
77 am_nnet1.
Read(ki.Stream(), binary_read);
81 Input ki(nnet2_rxfilename, &binary_read);
82 trans_model.
Read(ki.Stream(), binary_read);
83 am_nnet2.
Read(ki.Stream(), binary_read);
88 KALDI_WARN <<
"Parameter-dim mismatch, cannot show progress.";
94 if (!examples_rspecifier.empty()) {
96 const bool treat_as_gradient =
true;
97 nnet_gradient.
SetZero(treat_as_gradient);
99 std::vector<NnetExample> examples;
101 for (; !example_reader.Done(); example_reader.Next())
102 examples.push_back(example_reader.Value());
104 int32 num_examples = examples.size();
109 for (int32 s = 0; s < num_segments; s++) {
111 BaseFloat start = (s + 0.0) / num_segments,
112 end = (s + 1.0) / num_segments, middle = 0.5 * (start + end);
114 interp_nnet.
Scale(middle);
115 interp_nnet.AddNnet(1.0 - middle, am_nnet1.
GetNnet());
118 const bool treat_as_gradient =
true;
119 nnet_gradient.
SetZero(treat_as_gradient);
122 batch_size, &nnet_gradient);
123 KALDI_LOG <<
"At position " << middle <<
", objf per frame is " << objf_per_frame;
126 nnet_gradient.ComponentDotProducts(am_nnet1.
GetNnet(), &old_dotprod);
127 nnet_gradient.ComponentDotProducts(am_nnet2.
GetNnet(), &new_dotprod);
128 old_dotprod.Scale(1.0 / num_examples);
129 new_dotprod.Scale(1.0 / num_examples);
130 diff.AddVec(1.0/ num_segments, new_dotprod);
131 diff.AddVec(-1.0 / num_segments, old_dotprod);
132 KALDI_VLOG(1) <<
"By segment " << s <<
", objf change is " << diff;
134 KALDI_LOG <<
"Total objf change per component is " << diff;
135 if (num_examples == 0) ret = 1;
141 int32 num_updatable = diff_nnet.NumUpdatableComponents();
143 diff_nnet.ComponentDotProducts(diff_nnet, &dot_prod);
144 dot_prod.ApplyPow(0.5);
145 KALDI_LOG <<
"Parameter differences per layer are " 151 baseline_prod.ApplyPow(0.5);
152 dot_prod.DivElements(baseline_prod);
153 KALDI_LOG <<
"Relative parameter differences per layer are " 158 }
catch(
const std::exception &e) {
159 std::cerr << e.what() <<
'\n';
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void AddNnet(const VectorBase< BaseFloat > &scales, const Nnet &other)
For each updatable component, adds to it the corresponding element of "other" times the appropriate...
int32 NumUpdatableComponents() const
Returns the number of updatable components.
double ComputeNnetGradient(const Nnet &nnet, const std::vector< NnetExample > &validation_set, int32 batch_size, Nnet *gradient)
ComputeNnetGradient is mostly used to compute gradients on validation sets; it divides the example in...
void ComponentDotProducts(const Nnet &other, VectorBase< BaseFloat > *dot_prod) const
void Read(std::istream &is, bool binary)
virtual int32 GetParameterDim() const
void Scale(BaseFloat scale)
Scales all the Components with the same scale.
void SetZero(bool treat_as_gradient)
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
A class representing a vector.
const Nnet & GetNnet() const