29 using namespace kaldi;
32 typedef kaldi::int64 int64;
35 "Given an old and a new 'raw' nnet3 network and some training examples\n" 36 "(possibly held-out), show the average objective function given the\n" 37 "mean of the two networks, and the breakdown by component of why this\n" 38 "happened (computed from derivative information). Also shows parameter\n" 39 "differences per layer. If training examples not provided, only shows\n" 40 "parameter differences per layer.\n" 42 "Usage: nnet3-show-progress [options] <old-net-in> <new-net-in>" 43 " [<training-examples-in>]\n" 44 "e.g.: nnet3-show-progress 1.nnet 2.nnet ark:valid.egs\n";
48 int32 num_segments = 1;
49 std::string use_gpu =
"no";
53 po.Register(
"num-segments", &num_segments,
54 "Number of line segments used for computing derivatives");
55 po.Register(
"use-gpu", &use_gpu,
56 "yes|no|optional|wait, only has effect if compiled with CUDA");
61 if (po.NumArgs() < 2 || po.NumArgs() > 3) {
67 CuDevice::Instantiate().SelectGpuId(use_gpu);
70 std::string nnet1_rxfilename = po.GetArg(1),
71 nnet2_rxfilename = po.GetArg(2),
72 examples_rspecifier = po.GetOptArg(3);
79 KALDI_WARN <<
"Parameter-dim mismatch, cannot show progress.";
83 if (!examples_rspecifier.empty() &&
IsSimpleNnet(nnet1)) {
84 std::vector<NnetExample> examples;
86 for (; !example_reader.Done(); example_reader.Next())
87 examples.push_back(example_reader.Value());
89 int32 num_examples = examples.size();
91 if (num_examples == 0)
97 for (int32 s = 0; s < num_segments; s++) {
99 BaseFloat start = (s + 0.0) / num_segments,
100 end = (s + 1.0) / num_segments, middle = 0.5 * (start + end);
101 Nnet interp_nnet(nnet2);
103 AddNnet(nnet1, 1.0 - middle, &interp_nnet);
106 std::vector<NnetExample>::const_iterator eg_iter = examples.begin(),
107 eg_end = examples.end();
108 for (; eg_iter != eg_end; ++eg_iter)
109 prob_computer.Compute(*eg_iter);
113 prob_computer.PrintTotalStats();
114 const Nnet &nnet_gradient = prob_computer.GetDeriv();
116 <<
", objf per frame is " << objf_per_frame;
121 old_dotprod.Scale(1.0 / objf_info->
tot_weight);
122 new_dotprod.Scale(1.0 / objf_info->
tot_weight);
123 diff.AddVec(1.0/ num_segments, new_dotprod);
124 diff.AddVec(-1.0 / num_segments, old_dotprod);
125 KALDI_VLOG(1) <<
"By segment " << s <<
", objf change is " 128 KALDI_LOG <<
"Total objf change per component is " 133 Nnet diff_nnet(nnet1);
134 AddNnet(nnet2, -1.0, &diff_nnet);
136 KALDI_VLOG(1) <<
"Printing info for the difference between the neural nets: " 142 dot_prod.ApplyPow(0.5);
143 KALDI_LOG <<
"Parameter differences per layer are " 147 new_prod(num_updatable);
150 baseline_prod.ApplyPow(0.5);
151 new_prod.ApplyPow(0.5);
153 KALDI_LOG <<
"Norms of parameter matrices from <new-nnet-in> are " 156 dot_prod.DivElements(baseline_prod);
157 KALDI_LOG <<
"Relative parameter differences per layer are " 161 CuDevice::Instantiate().PrintProfile();
164 }
catch(
const std::exception &e) {
165 std::cerr << e.what() <<
'\n';
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void ScaleNnet(BaseFloat scale, Nnet *nnet)
Scales the nnet parameters and stats by this scale.
std::string PrintVectorPerUpdatableComponent(const Nnet &nnet, const VectorBase< BaseFloat > &vec)
This function is for printing, to a string, a vector with one element per updatable component of the ...
void ComponentDotProducts(const Nnet &nnet1, const Nnet &nnet2, VectorBase< BaseFloat > *dot_prod)
Returns dot products between two networks of the same structure (calls the DotProduct functions of th...
int32 GetVerboseLevel()
Get verbosity level, usually set via the command-line '--verbose=' switch.
This class is for computing cross-entropy and accuracy values in a neural network, for diagnostics.
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
int32 NumParameters(const Nnet &src)
Returns the total of the number of parameters in the updatable components of the nnet.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
void Register(OptionsItf *opts)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
bool IsSimpleNnet(const Nnet &nnet)
This function returns true if the nnet has the following properties: It has an output called "output"...
A class representing a vector.
int32 NumUpdatableComponents(const Nnet &dest)
Returns the number of updatable components in the nnet.
void AddNnet(const Nnet &src, BaseFloat alpha, Nnet *dest)
Does *dest += alpha * src (affects nnet parameters and stored stats).