34 int main(
int argc,
char *argv[]) {
35 using namespace kaldi;
41 "Perform one iteration of Multi-stream training, per-utterance BPTT for (B)LSTMs.\n" 42 "The updates are done per-utterance, while several utterances are \n" 43 "processed at the same time.\n" 45 "Usage: nnet-train-multistream-perutt [options] <feature-rspecifier> <labels-rspecifier> <model-in> [<model-out>]\n" 46 "e.g.: nnet-train-blstm-streams scp:feats.scp ark:targets.ark nnet.init nnet.iter1\n";
57 po.
Register(
"binary", &binary,
"Write model in binary mode");
59 bool crossvalidate =
false;
60 po.
Register(
"cross-validate", &crossvalidate,
61 "Perform cross-validation (no backpropagation)");
63 std::string feature_transform;
64 po.
Register(
"feature-transform", &feature_transform,
65 "Feature transform in Nnet format");
67 int32 length_tolerance = 5;
68 po.
Register(
"length-tolerance", &length_tolerance,
69 "Allowed length difference of features/targets (frames)");
71 std::string frame_weights;
72 po.
Register(
"frame-weights", &frame_weights,
73 "Per-frame weights to scale gradients (frame selection/weighting).");
75 int32 num_streams = 20;
76 po.
Register(
"num-streams", &num_streams,
77 "Number of sentences processed in parallel (can be lower if sentences are long)");
79 double max_frames = 8000;
80 po.
Register(
"max-frames", &max_frames,
81 "Max number of frames to be processed");
84 po.
Register(
"randomize", &dummy,
"Dummy option.");
86 std::string use_gpu =
"yes";
88 "yes|no|optional, only has effect if compiled with CUDA");
92 if (po.
NumArgs() != 3 + (crossvalidate ? 0 : 1)) {
97 std::string feature_rspecifier = po.
GetArg(1),
98 targets_rspecifier = po.
GetArg(2),
99 model_filename = po.
GetArg(3);
101 std::string target_model_filename;
102 if (!crossvalidate) {
103 target_model_filename = po.
GetArg(4);
106 using namespace kaldi;
111 CuDevice::Instantiate().SelectGpuId(use_gpu);
115 if ( feature_transform !=
"" ) {
116 nnet_transf.
Read(feature_transform);
120 nnet.
Read(model_filename);
128 kaldi::int64 total_frames = 0;
134 if (frame_weights !=
"") {
135 weights_reader.
Open(frame_weights);
139 Xent xent(loss_opts);
144 KALDI_LOG << (crossvalidate ?
"CROSS-VALIDATION" :
"TRAINING")
149 matrix_buffer.
Init(&feature_reader);
155 while (!matrix_buffer.Done()) {
158 std::vector<Matrix<BaseFloat> > feats_utt;
159 std::vector<Posterior> labels_utt;
160 std::vector<Vector<BaseFloat> > weights_utt;
161 std::vector<int32> frame_num_utt;
163 matrix_buffer.ResetLength();
164 for (matrix_buffer.Next(); !matrix_buffer.Done(); matrix_buffer.Next()) {
165 std::string utt = matrix_buffer.Key();
167 if (!targets_reader.
HasKey(utt)) {
173 if (frame_weights !=
"" && !weights_reader.
HasKey(utt)) {
174 KALDI_WARN << utt <<
", missing frame-weights";
184 if (mat.
NumRows() > max_frames)
continue;
187 if (frame_weights !=
"") {
188 weights = weights_reader.
Value(utt);
197 std::vector<int32> length;
198 length.push_back(mat.
NumRows());
199 length.push_back(targets.size());
200 length.push_back(weights.
Dim());
202 int32 min = *std::min_element(length.begin(), length.end());
203 int32 max = *std::max_element(length.begin(), length.end());
205 if (max - min < length_tolerance) {
207 if (targets.size() != min) targets.resize(min);
210 KALDI_WARN <<
"Length mismatch! Targets " << targets.size()
211 <<
", features " << mat.
NumRows() <<
", " << utt;
222 labels_utt.push_back(targets);
223 weights_utt.push_back(weights);
224 frame_num_utt.push_back(feats_transf.
NumRows());
226 if (frame_num_utt.size() == num_streams)
break;
229 int32 max = (*std::max_element(frame_num_utt.begin(), frame_num_utt.end()));
230 if (max * (frame_num_utt.size() + 1) > max_frames)
break;
234 if (frame_num_utt.size() == 0)
continue;
242 int32 n_streams = frame_num_utt.size();
243 int32 frame_num_padded = (*std::max_element(frame_num_utt.begin(), frame_num_utt.end()));
244 int32 feat_dim = feats_utt.front().NumCols();
248 feat_mat_host.
Resize(n_streams * frame_num_padded, feat_dim,
kSetZero);
249 target_host.resize(n_streams * frame_num_padded);
252 for (int32 s = 0; s < n_streams; s++) {
254 for (int32 r = 0; r < frame_num_utt[s]; r++) {
255 feat_mat_host.
Row(r*n_streams + s).CopyFromVec(mat_tmp.
Row(r));
259 for (int32 s = 0; s < n_streams; s++) {
260 const Posterior& target_tmp = labels_utt[s];
261 for (int32 r = 0; r < frame_num_utt[s]; r++) {
262 target_host[r*n_streams + s] = target_tmp[r];
267 for (int32 s = 0; s < n_streams; s++) {
269 for (int32 r = 0; r < frame_num_utt[s]; r++) {
270 weight_host(r*n_streams + s) = weight_tmp(r);
279 std::ostringstream os;
281 for (
size_t i = 0;
i < frame_num_utt.size();
i++) {
282 os << frame_num_utt[
i] <<
" ";
285 KALDI_LOG <<
"frame_num_utt[" << frame_num_utt.size() <<
"]" << os.str();
288 nnet.
ResetStreams(std::vector<int32>(frame_num_utt.size(), 1));
294 xent.
Eval(weight_host, nnet_out, target_host, &obj_diff);
297 if (!crossvalidate) {
302 if (total_frames == 0) {
303 KALDI_LOG <<
"### After " << total_frames <<
" frames,";
306 if (!crossvalidate) {
312 kaldi::int64 tmp_frames = total_frames;
314 num_done += frame_num_utt.size();
315 total_frames += std::accumulate(frame_num_utt.begin(), frame_num_utt.end(), 0);
321 if (tmp_frames / F != total_frames / F) {
322 KALDI_VLOG(2) <<
"### After " << total_frames <<
" frames,";
325 if (!crossvalidate) {
334 KALDI_LOG <<
"### After " << total_frames <<
" frames,";
337 if (!crossvalidate) {
342 if (!crossvalidate) {
343 nnet.
Write(target_model_filename, binary);
347 KALDI_LOG <<
"Done " << num_done <<
" files, " << num_no_tgt_mat
348 <<
" with no tgt_mats, " << num_other_error
349 <<
" with other errors. " 350 <<
"[" << (crossvalidate ?
"CROSS-VALIDATION" :
"TRAINING")
351 <<
", " << time.
Elapsed() / 60 <<
" min, " 352 <<
"fps" << total_frames / time.
Elapsed() <<
"]";
356 CuDevice::Instantiate().PrintProfile();
359 }
catch(
const std::exception &e) {
360 std::cerr << e.what();
void Backpropagate(const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
Perform backward pass through the network.
void Init(SequentialBaseFloatMatrixReader *reader, MatrixBufferOptions opts=MatrixBufferOptions())
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void ResetStreams(const std::vector< int32 > &stream_reset_flag)
Reset streams in multi-stream training.
int main(int argc, char *argv[])
void Propagate(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network.
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
void SetSeqLengths(const std::vector< int32 > &sequence_lengths)
Set sequence length in LSTM multi-stream training.
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
int32 GetVerboseLevel()
Get verbosity level, usually set via command line '--verbose=' switch.
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename'.
bool Open(const std::string &rspecifier)
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
const T & Value(const std::string &key)
std::string InfoBackPropagate(bool header=true) const
Create string with back-propagation-buffer statistics.
void Eval(const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const CuMatrixBase< BaseFloat > &target, CuMatrix< BaseFloat > *diff)
Evaluate cross entropy using target-matrix (supports soft labels).
std::string Report()
Generate string with error report.
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename'.
void Register(OptionsItf *opts)
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
MatrixIndexT Dim() const
Returns the dimension of the vector.
A buffer for caching (utterance-key, feature-matrix) pairs.
bool HasKey(const std::string &key)
void Feedforward(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network (with 2 swapping buffers).
int NumArgs() const
Number of positional parameters (cf. argc-1).
A class representing a vector.
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
void Set(Real f)
Set all members of a vector to a specified value.
void SetDropoutRate(BaseFloat r)
Set the dropout rate.
std::string InfoGradient(bool header=true) const
Create string with per-component gradient statistics.
std::string InfoPropagate(bool header=true) const
Create string with propagation-buffer statistics.
std::string Info() const
Create string with human readable description of the nnet.
void SetTrainOptions(const NnetTrainOptions &opts)
Set hyper-parameters of the training (pushes to all UpdatableComponents).
void Resize(const MatrixIndexT r, const MatrixIndexT c, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Sets matrix to a specified size (zero is OK as long as both r and c are zero).
std::string ReportPerClass()
Generate string with per-class error report.
void Register(OptionsItf *opts)
MatrixIndexT NumRows() const
Dimensions.
double Elapsed() const
Returns time in seconds.