78 using namespace kaldi;
81 typedef kaldi::int64 int64;
84 "Using a subset of training or held-out examples, compute the average\n" 85 "over the first n nnet3 models where we maxize the objective function\n" 86 "for n. Note that the order of models has been reversed before\n" 87 "being fed into this binary. So we are actually combining last n models.\n" 88 "Inputs and outputs are 'raw' nnets.\n" 90 "Usage: nnet3-combine [options] <nnet-in1> <nnet-in2> ... <nnet-inN> <valid-examples-in> <nnet-out>\n" 93 " nnet3-combine 1.1.raw 1.2.raw 1.3.raw ark:valid.egs 2.raw\n";
95 bool binary_write =
true;
96 int32 max_objective_evaluations = 30;
97 bool batchnorm_test_mode =
false,
98 dropout_test_mode =
true;
99 std::string use_gpu =
"yes";
102 po.Register(
"binary", &binary_write,
"Write output in binary mode");
103 po.Register(
"max-objective-evaluations", &max_objective_evaluations,
"The " 104 "maximum number of objective evaluations in order to figure " 105 "out the best number of models to combine. It helps to speedup " 106 "if the number of models provided to this binary is quite " 107 "large (e.g. several hundred).");
108 po.Register(
"batchnorm-test-mode", &batchnorm_test_mode,
109 "If true, set test-mode to true on any BatchNormComponents " 110 "while evaluating objectives.");
111 po.Register(
"dropout-test-mode", &dropout_test_mode,
112 "If true, set test-mode to true on any DropoutComponents and " 113 "DropoutMaskComponents while evaluating objectives.");
114 po.Register(
"use-gpu", &use_gpu,
115 "yes|no|optional|wait, only has effect if compiled with CUDA");
119 if (po.NumArgs() < 3) {
125 CuDevice::Instantiate().SelectGpuId(use_gpu);
129 nnet_rxfilename = po.GetArg(1),
130 valid_examples_rspecifier = po.GetArg(po.NumArgs() - 1),
131 nnet_wxfilename = po.GetArg(po.NumArgs());
135 Nnet moving_average_nnet(nnet), best_nnet(nnet);
139 std::vector<NnetExample> egs;
145 valid_examples_rspecifier);
146 for (; !example_reader.Done(); example_reader.Next())
147 egs.push_back(example_reader.Value());
148 KALDI_LOG <<
"Read " << egs.size() <<
" examples.";
153 int32 best_num_to_combine = 1;
155 init_objf =
ComputeObjf(batchnorm_test_mode, dropout_test_mode,
156 egs, moving_average_nnet, &prob_computer),
157 best_objf = init_objf;
158 KALDI_LOG <<
"objective function using the last model is " << init_objf;
160 int32 num_nnets = po.NumArgs() - 2;
163 int32 num_to_add = (num_nnets + max_objective_evaluations - 1) /
164 max_objective_evaluations;
165 for (int32
n = 1;
n < num_nnets;
n++) {
171 if ((
n - 1) % num_to_add == num_to_add - 1 ||
n == num_nnets - 1) {
172 double objf =
ComputeObjf(batchnorm_test_mode, dropout_test_mode,
173 egs, moving_average_nnet, &prob_computer);
175 <<
" models, objective function is " << objf;
176 if (objf > best_objf) {
178 best_nnet = moving_average_nnet;
179 best_num_to_combine =
n + 1;
183 KALDI_LOG <<
"Combining " << best_num_to_combine
184 <<
" nnets, objective function changed from " << init_objf
185 <<
" to " << best_objf;
191 CuDevice::Instantiate().PrintProfile();
195 KALDI_LOG <<
"Finished combining neural nets, wrote model to " 197 }
catch(
const std::exception &e) {
198 std::cerr << e.what() <<
'\n';
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void UpdateNnetMovingAverage(int32 num_models, const Nnet &nnet, Nnet *moving_average_nnet)
This class is for computing cross-entropy and accuracy values in a neural network, for diagnostics.
double ComputeObjf(bool batchnorm_test_mode, bool dropout_test_mode, const std::vector< NnetExample > &egs, const Nnet &nnet, NnetComputeProb *prob_computer)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
void RecomputeStats(const std::vector< NnetChainExample > &egs, const chain::ChainTrainingOptions &chain_config_in, const fst::StdVectorFst &den_fst, Nnet *nnet)
This function zeros the stored component-level stats in the nnet using ZeroComponentStats(), then recomputes them with the supplied egs.
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
bool HasBatchnorm(const Nnet &nnet)
Returns true if nnet has at least one component of type BatchNormComponent.
#define KALDI_ASSERT(cond)
void WriteKaldiObject(const C &c, const std::string &filename, bool binary)