110 using namespace kaldi;
116 "Perform one iteration of Multi-stream training, truncated BPTT for LSTMs.\n" 117 "The training targets are pdf-posteriors, usually prepared by ali-to-post.\n" 118 "The updates are per-utterance.\n" 120 "Usage: nnet-train-multistream [options] " 121 "<feature-rspecifier> <targets-rspecifier> <model-in> [<model-out>]\n" 122 "e.g.: nnet-train-lstm-streams scp:feature.scp ark:posterior.ark nnet.init nnet.iter1\n";
132 po.Register(
"binary", &binary,
"Write output in binary mode");
134 bool crossvalidate =
false;
135 po.Register(
"cross-validate", &crossvalidate,
136 "Perform cross-validation (don't back-propagate)");
138 std::string feature_transform;
139 po.Register(
"feature-transform", &feature_transform,
140 "Feature transform in Nnet format");
142 std::string objective_function =
"xent";
143 po.Register(
"objective-function", &objective_function,
144 "Objective function : xent|mse");
146 int32 length_tolerance = 5;
147 po.Register(
"length-tolerance", &length_tolerance,
148 "Allowed length difference of features/targets (frames)");
150 std::string frame_weights;
151 po.Register(
"frame-weights", &frame_weights,
152 "Per-frame weights to scale gradients (frame selection/weighting).");
154 int32 batch_size = 20;
155 po.Register(
"batch-size", &batch_size,
156 "Length of 'one stream' in the Multi-stream training");
158 int32 num_streams = 4;
159 po.Register(
"num-streams", &num_streams,
160 "Number of streams in the Multi-stream training");
163 po.Register(
"randomize", &dummy,
"Dummy option.");
165 std::string use_gpu=
"yes";
166 po.Register(
"use-gpu", &use_gpu,
167 "yes|no|optional, only has effect if compiled with CUDA");
171 if (po.NumArgs() != 3 + (crossvalidate ? 0 : 1)) {
176 std::string feature_rspecifier = po.GetArg(1),
177 targets_rspecifier = po.GetArg(2),
178 model_filename = po.GetArg(3);
180 std::string target_model_filename;
181 if (!crossvalidate) {
182 target_model_filename = po.GetArg(4);
185 using namespace kaldi;
190 CuDevice::Instantiate().SelectGpuId(use_gpu);
194 if (feature_transform !=
"") {
195 nnet_transf.
Read(feature_transform);
199 nnet.
Read(model_filename);
207 kaldi::int64 total_frames = 0;
212 if (frame_weights !=
"") {
213 weights_reader.
Open(frame_weights);
216 Xent xent(loss_opts);
221 KALDI_LOG << (crossvalidate ?
"CROSS-VALIDATION" :
"TRAINING")
229 std::vector<Matrix<BaseFloat> > feats_utt(num_streams);
230 std::vector<Posterior> labels_utt(num_streams);
231 std::vector<Vector<BaseFloat> > weights_utt(num_streams);
232 std::vector<int32> cursor_utt(num_streams);
233 std::vector<int32> new_utt_flags(num_streams);
241 new_utt_flags.assign(num_streams, 0);
242 for (
int s = 0; s < num_streams; s++) {
244 if (cursor_utt[s] >= feats_utt[s].NumRows()) {
249 if (
ReadData(feature_reader, target_reader, weights_reader,
251 &feats, &targets, &weights,
252 &num_no_tgt_mat, &num_other_error)) {
268 labels_utt[s] = targets;
269 weights_utt[s] = weights;
271 new_utt_flags[s] = 1;
278 size_t inactive_streams = 0;
279 for (int32 s = 0; s < num_streams; s++) {
280 if (feats_utt[s].NumRows() - cursor_utt[s] <= 0) {
281 inactive_streams += 1;
284 if (inactive_streams >= 1) {
285 KALDI_LOG <<
"No more data to re-fill one of the streams, end of the training!";
286 KALDI_LOG <<
"(remaining stubs of data are discarded, don't overtrain on them)";
291 std::vector<int32> frame_num_utt;
299 int32 n_streams = num_streams;
303 target_host.resize(n_streams * batch_size);
305 frame_num_utt.resize(n_streams, 0);
308 for (int32 s = 0; s < n_streams; s++) {
309 int32 num_rows = std::max(0, feats_utt[s].NumRows() - cursor_utt[s]);
310 frame_num_utt[s] = std::min(batch_size, num_rows);
315 for (int32 s = 0; s < n_streams; s++) {
316 if (frame_num_utt[s] > 0) {
317 auto mat_tmp = feats_utt[s].RowRange(cursor_utt[s], frame_num_utt[s]);
318 for (int32 r = 0; r < frame_num_utt[s]; r++) {
319 feat_mat_host.
Row(r*n_streams + s).CopyFromVec(mat_tmp.Row(r));
324 for (int32 s = 0; s < n_streams; s++) {
325 for (int32 r = 0; r < frame_num_utt[s]; r++) {
326 target_host[r*n_streams + s] = labels_utt[s][cursor_utt[s] + r];
331 for (int32 s = 0; s < n_streams; s++) {
332 if (frame_num_utt[s] > 0) {
333 auto weight_tmp = weights_utt[s].Range(cursor_utt[s], frame_num_utt[s]);
334 for (int32 r = 0; r < frame_num_utt[s]; r++) {
335 weight_host(r*n_streams + s) = weight_tmp(r);
342 for (int32 s = 0; s < n_streams; s++) {
343 cursor_utt[s] += frame_num_utt[s];
354 std::ostringstream os;
356 for (
size_t i = 0;
i < cursor_utt.size();
i++) {
357 os << cursor_utt[
i] <<
" ";
360 KALDI_LOG <<
"cursor_utt[" << cursor_utt.size() <<
"]" << os.str();
364 std::ostringstream os;
366 for (
size_t i = 0;
i < frame_num_utt.size();
i++) {
367 os << frame_num_utt[
i] <<
" ";
370 KALDI_LOG <<
"frame_num_utt[" << frame_num_utt.size() <<
"]" << os.str();
382 if (objective_function ==
"xent") {
383 xent.Eval(weight_host, nnet_out, target_host, &obj_diff);
384 }
else if (objective_function ==
"mse") {
385 mse.Eval(weight_host, nnet_out, target_host, &obj_diff);
387 KALDI_ERR <<
"Unknown objective function code : " 388 << objective_function;
391 if (!crossvalidate) {
398 if (total_frames == 0) {
399 KALDI_LOG <<
"### After " << total_frames <<
" frames,";
402 if (!crossvalidate) {
408 kaldi::int64 tmp_frames = total_frames;
410 num_done += std::accumulate(new_utt_flags.begin(), new_utt_flags.end(), 0);
411 total_frames += std::accumulate(frame_num_utt.begin(), frame_num_utt.end(), 0);
417 if (tmp_frames / F != total_frames / F) {
418 KALDI_VLOG(2) <<
"### After " << total_frames <<
" frames,";
421 if (!crossvalidate) {
430 KALDI_LOG <<
"### After " << total_frames <<
" frames,";
433 if (!crossvalidate) {
438 if (!crossvalidate) {
439 nnet.
Write(target_model_filename, binary);
442 if (objective_function ==
"xent") {
446 KALDI_LOG <<
"Done " << num_done <<
" files, " 447 << num_no_tgt_mat <<
" with no tgt_mats, " 448 << num_other_error <<
" with other errors. " 449 <<
"[" << (crossvalidate ?
"CROSS-VALIDATION" :
"TRAINING")
450 <<
", " << time.
Elapsed() / 60 <<
" min, processing " 451 << total_frames / time.
Elapsed() <<
" frames per sec, " 452 <<
"GPU_time " << 100.*time_gpu/time.
Elapsed() <<
"% ]";
454 if (objective_function ==
"xent") {
456 }
else if (objective_function ==
"mse") {
459 KALDI_ERR <<
"Unknown objective function code : " << objective_function;
463 CuDevice::Instantiate().PrintProfile();
467 }
catch(
const std::exception &e) {
468 std::cerr << e.what();
void Backpropagate(const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
Perform backward pass through the network.
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void ResetStreams(const std::vector< int32 > &stream_reset_flag)
Reset streams in multi-stream training.
void Propagate(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network.
void SetSeqLengths(const std::vector< int32 > &sequence_lengths)
Set sequence lengths in LSTM multi-stream training.
int32 GetVerboseLevel()
Get verbosity level, usually set via the command-line '--verbose=' switch.
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename'.
int32 InputDim() const
Dimensionality of network input (input feature dim.).
bool Open(const std::string &rspecifier)
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
Allows random access to a collection of objects in an archive or script file; see The Table concept...
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
std::string InfoBackPropagate(bool header=true) const
Create a string with back-propagation-buffer statistics.
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename'.
void Register(OptionsItf *opts)
void Feedforward(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network (with 2 swapping buffers).
A class representing a vector.
void SetDropoutRate(BaseFloat r)
Set the dropout rate.
std::string InfoGradient(bool header=true) const
Create a string with per-component gradient statistics.
std::string InfoPropagate(bool header=true) const
Create a string with propagation-buffer statistics.
std::string Info() const
Create a string with a human-readable description of the nnet.
void SetTrainOptions(const NnetTrainOptions &opts)
Set hyper-parameters of the training (pushes to all UpdatableComponents).
void Resize(const MatrixIndexT r, const MatrixIndexT c, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Sets matrix to a specified size (zero is OK as long as both r and c are zero).
void Register(OptionsItf *opts)
double Elapsed() const
Returns time in seconds.
bool ReadData(SequentialBaseFloatMatrixReader &feature_reader, RandomAccessPosteriorReader &target_reader, RandomAccessBaseFloatVectorReader &weights_reader, int32 length_tolerance, Matrix< BaseFloat > *feats, Posterior *targets, Vector< BaseFloat > *weights, int32 *num_no_tgt_mat, int32 *num_other_error)