30 using namespace kaldi;
36 "Perform one iteration (epoch) of Neural Network training with\n" 37 "mini-batch Stochastic Gradient Descent. The training targets\n" 38 "are usually pdf-posteriors, prepared by ali-to-post.\n" 39 "Usage: nnet-train-frmshuff [options] <feature-rspecifier> <targets-rspecifier> <model-in> [<model-out>]\n" 40 "e.g.: nnet-train-frmshuff scp:feats.scp ark:posterior.ark nnet.init nnet.iter1\n";
52 po.Register(
"binary", &binary,
"Write output in binary mode");
54 bool crossvalidate =
false;
55 po.Register(
"cross-validate", &crossvalidate,
56 "Perform cross-validation (don't back-propagate)");
58 bool randomize =
true;
59 po.Register(
"randomize", &randomize,
60 "Perform the frame-level shuffling within the Cache::");
62 std::string feature_transform;
63 po.Register(
"feature-transform", &feature_transform,
64 "Feature transform in Nnet format");
66 std::string objective_function =
"xent";
67 po.Register(
"objective-function", &objective_function,
68 "Objective function : xent|mse|multitask");
70 int32 max_frames = 360000;
71 po.Register(
"max-frames", &max_frames,
72 "Maximum number of frames an utterance can have (skipped if longer)");
74 int32 length_tolerance = 5;
75 po.Register(
"length-tolerance", &length_tolerance,
76 "Allowed length mismatch of features/targets/weights " 77 "(in frames, we truncate to the shortest)");
79 std::string frame_weights;
80 po.Register(
"frame-weights", &frame_weights,
81 "Per-frame weights, used to re-scale gradients.");
83 std::string utt_weights;
84 po.Register(
"utt-weights", &utt_weights,
85 "Per-utterance weights, used to re-scale frame-weights.");
87 std::string use_gpu=
"yes";
88 po.Register(
"use-gpu", &use_gpu,
89 "yes|no|optional, only has effect if compiled with CUDA");
93 if (po.NumArgs() != 3 + (crossvalidate ? 0 : 1)) {
98 std::string feature_rspecifier = po.GetArg(1),
99 targets_rspecifier = po.GetArg(2),
100 model_filename = po.GetArg(3);
102 std::string target_model_filename;
103 if (!crossvalidate) {
104 target_model_filename = po.GetArg(4);
107 using namespace kaldi;
112 CuDevice::Instantiate().SelectGpuId(use_gpu);
116 if (feature_transform !=
"") {
117 nnet_transf.
Read(feature_transform);
121 nnet.
Read(model_filename);
129 kaldi::int64 total_frames = 0;
134 if (frame_weights !=
"") {
135 weights_reader.
Open(frame_weights);
138 if (utt_weights !=
"") {
139 utt_weights_reader.
Open(utt_weights);
147 Xent xent(loss_opts);
151 if (0 == objective_function.compare(0, 9,
"multitask")) {
157 multitask.InitFromString(objective_function);
163 KALDI_LOG << (crossvalidate ?
"CROSS-VALIDATION" :
"TRAINING")
170 double time_io_accu = 0.0;
173 while (!feature_reader.Done()) {
176 CuDevice::Instantiate().CheckGpuHealth();
180 for ( ; !feature_reader.Done(); feature_reader.Next()) {
181 if (feature_randomizer.IsFull()) {
186 std::string utt = feature_reader.Key();
189 if (!targets_reader.HasKey(utt)) {
195 if (frame_weights !=
"" && !weights_reader.
HasKey(utt)) {
196 KALDI_WARN << utt <<
", missing per-frame weights";
201 if (utt_weights !=
"" && !utt_weights_reader.
HasKey(utt)) {
202 KALDI_WARN << utt <<
", missing per-utterance weight";
208 Posterior targets = targets_reader.Value(utt);
211 if (frame_weights !=
"") {
212 weights = weights_reader.
Value(utt);
218 if (utt_weights !=
"") {
221 if (w == 0.0)
continue;
226 time_io_accu += time_io.
Elapsed();
230 if (mat.
NumRows() > max_frames) {
231 KALDI_WARN <<
"Utterance too long, skipping! " << utt
232 <<
" (length " << mat.
NumRows() <<
", max_frames " 233 << max_frames <<
")";
241 std::vector<int32> length;
242 length.push_back(mat.
NumRows());
243 length.push_back(targets.size());
244 length.push_back(weights.
Dim());
246 int32 min = *std::min_element(length.begin(), length.end());
247 int32 max = *std::max_element(length.begin(), length.end());
249 if (max - min < length_tolerance) {
252 if (targets.size() != min) targets.resize(min);
255 KALDI_WARN <<
"Length mismatch! Targets " << targets.size()
256 <<
", features " << mat.
NumRows() <<
", " << utt;
269 if (weight_min == 0.0) {
271 std::vector<MatrixIndexT> keep_frames;
272 for (int32
i = 0;
i < weights.
Dim();
i++) {
273 if (weights(
i) > 0.0) {
274 keep_frames.push_back(
i);
279 if (keep_frames.size() == 0)
continue;
284 tmp_feats.Swap(&feats_transf);
288 for (int32
i = 0;
i < keep_frames.size();
i++) {
289 tmp_targets.push_back(targets[keep_frames[
i]]);
291 tmp_targets.swap(targets);
295 for (int32
i = 0;
i < keep_frames.size();
i++) {
296 tmp_weights(
i) = weights(keep_frames[
i]);
298 tmp_weights.Swap(&weights);
304 feature_randomizer.AddData(feats_transf);
305 targets_randomizer.AddData(targets);
306 weights_randomizer.AddData(weights);
313 if (!crossvalidate && randomize) {
314 const std::vector<int32>& mask =
315 randomizer_mask.Generate(feature_randomizer.NumFrames());
316 feature_randomizer.Randomize(mask);
317 targets_randomizer.Randomize(mask);
318 weights_randomizer.Randomize(mask);
322 for ( ; !feature_randomizer.Done(); feature_randomizer.Next(),
323 targets_randomizer.Next(),
324 weights_randomizer.Next()) {
327 const Posterior& nnet_tgt = targets_randomizer.Value();
334 if (objective_function ==
"xent") {
336 xent.Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
337 }
else if (objective_function ==
"mse") {
339 mse.Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
340 }
else if (0 == objective_function.compare(0, 9,
"multitask")) {
342 multitask.Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
344 KALDI_ERR <<
"Unknown objective function code : " << objective_function;
347 if (!crossvalidate) {
353 if (total_frames == 0) {
354 KALDI_LOG <<
"### After " << total_frames <<
" frames,";
356 if (!crossvalidate) {
365 static int32 counter = 0;
368 if (counter >= 25000) {
369 KALDI_VLOG(2) <<
"### After " << total_frames <<
" frames,";
371 if (!crossvalidate) {
379 total_frames += nnet_in.
NumRows();
384 KALDI_LOG <<
"### After " << total_frames <<
" frames,";
386 if (!crossvalidate) {
391 if (!crossvalidate) {
392 nnet.
Write(target_model_filename, binary);
395 KALDI_LOG <<
"Done " << num_done <<
" files, " 396 << num_no_tgt_mat <<
" with no tgt_mats, " 397 << num_other_error <<
" with other errors. " 398 <<
"[" << (crossvalidate ?
"CROSS-VALIDATION" :
"TRAINING")
399 <<
", " << (randomize ?
"RANDOMIZED" :
"NOT-RANDOMIZED")
400 <<
", " << time.
Elapsed() / 60 <<
" min, processing " 401 << total_frames / time.
Elapsed() <<
" frames per sec;" 402 <<
" i/o time " << 100.*time_io_accu/time.
Elapsed() <<
"%]";
404 if (objective_function ==
"xent") {
407 }
else if (objective_function ==
"mse") {
409 }
else if (0 == objective_function.compare(0, 9,
"multitask")) {
412 KALDI_ERR <<
"Unknown objective function code : " << objective_function;
416 CuDevice::Instantiate().PrintProfile();
420 }
catch(
const std::exception &e) {
421 std::cerr << e.what();
void Backpropagate(const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
Perform backward pass through the network.
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void Propagate(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network.
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
int32 GetVerboseLevel()
Get verbosity level, usually set via command line '--verbose=' switch.
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename'.
Generates randomly ordered vector of indices.
bool Open(const std::string &rspecifier)
Randomizes elements of a vector according to a mask.
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
Real Min() const
Returns the minimum value of any element, or +infinity for the empty vector.
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Configuration variables that affect how frame-level shuffling is done.
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
const T & Value(const std::string &key)
std::string InfoBackPropagate(bool header=true) const
Create string with back-propagation-buffer statistics.
void CopyRows(const CuMatrixBase< Real > &src, const CuArrayBase< MatrixIndexT > &indexes)
Copies row r from row indexes[r] of src.
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename'.
void Register(OptionsItf *opts)
MatrixIndexT Dim() const
Returns the dimension of the vector.
void Scale(Real alpha)
Multiplies all elements by this constant.
bool HasKey(const std::string &key)
void Feedforward(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network (with 2 swapping buffers).
Shuffles rows of a matrix according to the indices in the mask.
Matrix for CUDA computing.
MatrixIndexT NumCols() const
A class representing a vector.
Class CuArray represents a vector of an integer or struct of type T.
#define KALDI_ASSERT(cond)
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
void Set(Real f)
Set all members of a vector to a specified value.
void Register(OptionsItf *opts)
void SetDropoutRate(BaseFloat r)
Set the dropout rate.
std::string InfoGradient(bool header=true) const
Create string with per-component gradient statistics.
std::string InfoPropagate(bool header=true) const
Create string with propagation-buffer statistics.
void SetTrainOptions(const NnetTrainOptions &opts)
Set hyper-parameters of the training (pushes to all UpdatableComponents).
void Resize(const MatrixIndexT r, const MatrixIndexT c, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Sets matrix to a specified size (zero is OK as long as both r and c are zero).
Randomizes elements of a vector according to a mask.
void Register(OptionsItf *opts)
MatrixIndexT NumRows() const
Dimensions.
double Elapsed() const
Returns time in seconds.