// Computes and returns the objective function on the examples in 'egs', given
// the model 'nnet'.  If either test mode is requested, a copy of 'nnet' is
// evaluated with the corresponding test mode set; otherwise 'prob_computer'
// (which must refer to 'nnet') is reset and reused.
double ComputeObjf(bool batchnorm_test_mode, bool dropout_test_mode,
                   const std::vector<NnetExample> &egs, const Nnet &nnet,
                   NnetComputeProb *prob_computer) {
  if (batchnorm_test_mode || dropout_test_mode) {
    Nnet nnet_copy(nnet);
    if (batchnorm_test_mode)
      SetBatchnormTestMode(true, &nnet_copy);
    if (dropout_test_mode)
      SetDropoutTestMode(true, &nnet_copy);
    NnetComputeProbOptions compute_prob_opts;
    NnetComputeProb prob_computer_test(compute_prob_opts, nnet_copy);
    return ComputeObjf(false, false, egs, nnet_copy, &prob_computer_test);
  } else {
    prob_computer->Reset();
    std::vector<NnetExample>::const_iterator iter = egs.begin(),
        end = egs.end();
    for (; iter != end; ++iter)
      prob_computer->Compute(*iter);
    double tot_weights,
        tot_objf = prob_computer->GetTotalObjective(&tot_weights);
    KALDI_ASSERT(tot_weights > 0.0);
    // A NaN or infinite total objective is mapped to -infinity.
    if (!(tot_objf == tot_objf && tot_objf - tot_objf == 0))
      return -std::numeric_limits<double>::infinity();
    // We prefer to deal with normalized (per-weight) objective functions.
    return tot_objf / tot_weights;
  }
}
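// Side note (illustrative sketch, not part of the original binary): the test
// "x == x && x - x == 0" used above is a portable finiteness check; x != x
// only for NaN, and x - x is NaN (not 0) for +/-infinity.  With <cmath> it
// could equivalently be written as std::isfinite(x).  The helper name below
// is hypothetical.
static inline bool DemoObjfIsFinite(double x) {
  return x == x && x - x == 0;  // true iff x is neither NaN nor +/-inf
}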
// Updates the moving average over 'num_models' models, given the average over
// the previous (num_models - 1) models and the new model 'nnet'.
void UpdateNnetMovingAverage(int32 num_models,
                             const Nnet &nnet, Nnet *moving_average_nnet) {
  KALDI_ASSERT(NumParameters(nnet) == NumParameters(*moving_average_nnet));
  ScaleNnet((num_models - 1.0) / num_models, moving_average_nnet);
  AddNnet(nnet, 1.0 / num_models, moving_average_nnet);
}
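// Illustrative sketch (not part of the original binary): the recurrence used
// by UpdateNnetMovingAverage, applied to plain doubles, reproduces the
// ordinary arithmetic mean of everything fed in so far:
//   avg_n = avg_{n-1} * (n - 1) / n  +  x_n / n.
// The function and variable names below are hypothetical.
static inline double DemoMovingAverage() {
  double avg = 1.0;                    // stands in for "model" 1
  const double models[] = {3.0, 5.0};  // stands in for "models" 2 and 3
  for (int n = 2; n <= 3; n++)
    avg = avg * (n - 1.0) / n + models[n - 2] / n;
  return avg;  // == 3.0 == (1.0 + 3.0 + 5.0) / 3
}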
int main(int argc, char *argv[]) {
  try {
    using namespace kaldi;
    using namespace kaldi::nnet3;
    typedef kaldi::int32 int32;
    typedef kaldi::int64 int64;

    const char *usage =
        "Using a subset of training or held-out examples, compute the average\n"
        "over the first n nnet3 models, where n is chosen to maximize the\n"
        "objective function.  Note that the order of models has been reversed\n"
        "before being fed into this binary, so we are actually combining the\n"
        "last n models.  Inputs and outputs are 'raw' nnets.\n"
        "\n"
        "Usage:  nnet3-combine [options] <nnet-in1> <nnet-in2> ... <nnet-inN> <valid-examples-in> <nnet-out>\n"
        "\n"
        "e.g.:\n"
        " nnet3-combine 1.1.raw 1.2.raw 1.3.raw ark:valid.egs 2.raw\n";
    bool binary_write = true;
    int32 max_objective_evaluations = 30;
    bool batchnorm_test_mode = false,
        dropout_test_mode = true;
    std::string use_gpu = "yes";
    ParseOptions po(usage);
    po.Register("binary", &binary_write, "Write output in binary mode");
    po.Register("max-objective-evaluations", &max_objective_evaluations, "The "
                "maximum number of objective evaluations used to figure out "
                "the best number of models to combine. It helps to speed "
                "things up if the number of models provided to this binary "
                "is quite large (e.g. several hundred).");
    po.Register("batchnorm-test-mode", &batchnorm_test_mode,
                "If true, set test-mode to true on any BatchNormComponents "
                "while evaluating objectives.");
    po.Register("dropout-test-mode", &dropout_test_mode,
                "If true, set test-mode to true on any DropoutComponents and "
                "DropoutMaskComponents while evaluating objectives.");
    po.Register("use-gpu", &use_gpu,
                "yes|no|optional|wait, only has effect if compiled with CUDA");
    po.Read(argc, argv);

    if (po.NumArgs() < 3) {
      po.PrintUsage();
      exit(1);
    }

#if HAVE_CUDA==1
    CuDevice::Instantiate().SelectGpuId(use_gpu);
#endif

    std::string nnet_rxfilename = po.GetArg(1),
        valid_examples_rspecifier = po.GetArg(po.NumArgs() - 1),
        nnet_wxfilename = po.GetArg(po.NumArgs());

    Nnet nnet;
    ReadKaldiObject(nnet_rxfilename, &nnet);
    Nnet moving_average_nnet(nnet), best_nnet(nnet);
    NnetComputeProbOptions compute_prob_opts;
    NnetComputeProb prob_computer(compute_prob_opts, moving_average_nnet);
    std::vector<NnetExample> egs;
    egs.reserve(10000);  // reserve plenty of space to minimize reallocation.

    { // This block reads the validation examples into 'egs'.
      SequentialNnetExampleReader example_reader(valid_examples_rspecifier);
      for (; !example_reader.Done(); example_reader.Next())
        egs.push_back(example_reader.Value());
      KALDI_LOG << "Read " << egs.size() << " examples.";
      KALDI_ASSERT(!egs.empty());
    }
    // First evaluate the objective function using just the last model.
    int32 best_num_to_combine = 1;
    double
        init_objf = ComputeObjf(batchnorm_test_mode, dropout_test_mode,
                                egs, moving_average_nnet, &prob_computer),
        best_objf = init_objf;
    KALDI_LOG << "objective function using the last model is " << init_objf;
    int32 num_nnets = po.NumArgs() - 2;
    // Each time we re-evaluate the objective function, we will first have
    // added 'num_to_add' models to the moving average (ceiling division, so
    // the total number of evaluations stays within max_objective_evaluations).
    int32 num_to_add = (num_nnets + max_objective_evaluations - 1) /
        max_objective_evaluations;
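    // Worked example (numbers hypothetical): with num_nnets = 300 and
    // --max-objective-evaluations=30, num_to_add = (300 + 29) / 30 = 10,
    // so the objective is re-evaluated after combining 11, 21, ..., 291
    // models and once more after all 300, i.e. about 30 evaluations
    // instead of 299.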
    for (int32 n = 1; n < num_nnets; n++) {
      ReadKaldiObject(po.GetArg(1 + n), &nnet);
      // Update the moving average to include this model.
      UpdateNnetMovingAverage(n + 1, nnet, &moving_average_nnet);
      // Evaluate the objective every 'num_to_add' models, and also after the
      // last model has been added.
      if ((n - 1) % num_to_add == num_to_add - 1 || n == num_nnets - 1) {
        double objf = ComputeObjf(batchnorm_test_mode, dropout_test_mode,
                                  egs, moving_average_nnet, &prob_computer);
        KALDI_LOG << "Combining last " << n + 1
                  << " models, objective function is " << objf;
        if (objf > best_objf) {
          best_objf = objf;
          best_nnet = moving_average_nnet;
          best_num_to_combine = n + 1;
        }
      }
    }
    KALDI_LOG << "Combining " << best_num_to_combine
              << " nnets, objective function changed from " << init_objf
              << " to " << best_objf;
    if (HasBatchnorm(nnet))
      RecomputeStats(egs, &best_nnet);

#if HAVE_CUDA==1
    CuDevice::Instantiate().PrintProfile();
#endif
    WriteKaldiObject(best_nnet, nnet_wxfilename, binary_write);
    KALDI_LOG << "Finished combining neural nets, wrote model to "
              << nnet_wxfilename;
    return 0;
  } catch(const std::exception &e) {
    std::cerr << e.what() << '\n';
    return -1;
  }
}