29 int main(
int argc,
char *argv[]) {
30 using namespace kaldi;
36 "Perform one iteration of NN training by SGD with per-utterance updates.\n" 37 "The training targets are represented as pdf-posteriors, usually prepared " 39 "Usage: nnet-train-perutt [options] " 40 "<feature-rspecifier> <targets-rspecifier> <model-in> [<model-out>]\n" 41 "e.g.: nnet-train-perutt scp:feature.scp ark:posterior.ark nnet.init nnet.iter1\n";
51 po.
Register(
"binary", &binary,
"Write output in binary mode");
53 bool crossvalidate =
false;
54 po.
Register(
"cross-validate", &crossvalidate,
55 "Perform cross-validation (don't backpropagate)");
57 std::string feature_transform;
58 po.
Register(
"feature-transform", &feature_transform,
59 "Feature transform in Nnet format");
61 std::string objective_function =
"xent";
62 po.
Register(
"objective-function", &objective_function,
63 "Objective function : xent|mse");
65 int32 length_tolerance = 5;
66 po.
Register(
"length-tolerance", &length_tolerance,
67 "Allowed length difference of features/targets (frames)");
69 std::string frame_weights;
70 po.
Register(
"frame-weights", &frame_weights,
71 "Per-frame weights to scale gradients (frame selection/weighting).");
74 po.
Register(
"max-frames",&max_frames,
"Maximum number of frames a segment can have to be processed");
76 std::string use_gpu=
"yes";
78 "yes|no|optional, only has effect if compiled with CUDA");
81 bool randomize =
false;
83 "Dummy, for compatibility with 'steps/nnet/train_scheduler.sh'");
88 if (po.
NumArgs() != 3 + (crossvalidate ? 0 : 1)) {
93 std::string feature_rspecifier = po.
GetArg(1),
94 targets_rspecifier = po.
GetArg(2),
95 model_filename = po.
GetArg(3);
97 std::string target_model_filename;
99 target_model_filename = po.
GetArg(4);
102 using namespace kaldi;
107 CuDevice::Instantiate().SelectGpuId(use_gpu);
111 if (feature_transform !=
"") {
112 nnet_transf.
Read(feature_transform);
116 nnet.
Read(model_filename);
124 kaldi::int64 total_frames = 0;
129 if (frame_weights !=
"") {
130 weights_reader.
Open(frame_weights);
133 Xent xent(loss_opts);
137 if (0 == objective_function.compare(0, 9,
"multitask")) {
149 KALDI_LOG << (crossvalidate?
"CROSS-VALIDATION":
"TRAINING") <<
" STARTED";
156 for ( ; !feature_reader.
Done(); feature_reader.
Next()) {
157 std::string utt = feature_reader.
Key();
160 if (!targets_reader.
HasKey(utt)) {
166 if (frame_weights !=
"" && !weights_reader.
HasKey(utt)) {
167 KALDI_WARN << utt <<
", missing per-frame weights";
169 feature_reader.
Next();
176 if (mat.
NumRows() > max_frames) {
178 <<
" that has " << mat.
NumRows() <<
" frames," 179 <<
" it is longer than '--max-frames'" << max_frames;
185 if (frame_weights !=
"") {
186 frm_weights = weights_reader.
Value(utt);
189 frm_weights.
Set(1.0);
194 std::vector<int32> length;
195 length.push_back(mat.
NumRows());
196 length.push_back(nnet_tgt.size());
197 length.push_back(frm_weights.
Dim());
199 int32 min = *std::min_element(length.begin(), length.end());
200 int32 max = *std::max_element(length.begin(), length.end());
202 if (max - min < length_tolerance) {
204 if (nnet_tgt.size() != min) nnet_tgt.resize(min);
207 KALDI_WARN << utt <<
", length mismatch of targets " << nnet_tgt.size()
208 <<
" and features " << mat.
NumRows();
220 if (objective_function ==
"xent") {
222 xent.
Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
223 }
else if (objective_function ==
"mse") {
225 mse.
Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
226 }
else if (0 == objective_function.compare(0, 9,
"multitask")) {
228 multitask.
Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
230 KALDI_ERR <<
"Unknown objective function code : " 231 << objective_function;
234 if (!crossvalidate) {
240 if (total_frames == 0) {
241 KALDI_LOG <<
"### After " << total_frames <<
" frames,";
243 if (!crossvalidate) {
252 static int32 counter = 0;
255 if (counter >= 25000) {
256 KALDI_VLOG(2) <<
"### After " << total_frames <<
" frames,";
258 if (!crossvalidate) {
267 total_frames += frm_weights.
Sum();
271 KALDI_LOG <<
"### After " << total_frames <<
" frames,";
273 if (!crossvalidate) {
278 if (!crossvalidate) {
279 nnet.
Write(target_model_filename, binary);
282 KALDI_LOG <<
"Done " << num_done <<
" files, " 283 << num_no_tgt_mat <<
" with no tgt_mats, " 284 << num_other_error <<
" with other errors. " 285 <<
"[" << (crossvalidate ?
"CROSS-VALIDATION" :
"TRAINING")
286 <<
", " << (randomize ?
"RANDOMIZED" :
"NOT-RANDOMIZED")
287 <<
", " << time.
Elapsed() / 60 <<
" min, processing " 288 << total_frames / time.
Elapsed() <<
" frames per sec.]";
290 if (objective_function ==
"xent") {
293 }
else if (objective_function ==
"mse") {
295 }
else if (0 == objective_function.compare(0, 9,
"multitask")) {
298 KALDI_ERR <<
"Unknown objective function code : " << objective_function;
302 CuDevice::Instantiate().PrintProfile();
306 }
catch(
const std::exception &e) {
307 std::cerr << e.what();
void Backpropagate(const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
Perform backward pass through the network.
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void Propagate(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network.
std::string Report()
Generate string with error report.
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
int32 GetVerboseLevel()
Get verbosity level, usually set via command line '--verbose=' switch.
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename'.
bool Open(const std::string &rspecifier)
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
void Register(const std::string &name, bool *ptr, const std::string &doc)
std::string Report()
Generate string with error report.
Allows random access to a collection of objects in an archive or script file; see The Table concept...
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
void Eval(const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const CuMatrixBase< BaseFloat > &target, CuMatrix< BaseFloat > *diff)
Evaluate mean square error using target-matrix.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
const T & Value(const std::string &key)
std::string InfoBackPropagate(bool header=true) const
Create string with back-propagation-buffer statistics.
void Eval(const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const CuMatrixBase< BaseFloat > &target, CuMatrix< BaseFloat > *diff)
Evaluate cross entropy using target-matrix (supports soft labels).
std::string Report()
Generate string with error report.
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
void Eval(const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const CuMatrixBase< BaseFloat > &target, CuMatrix< BaseFloat > *diff)
Evaluate mean square error using target-matrix.
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename'.
void Register(OptionsItf *opts)
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
MatrixIndexT Dim() const
Returns the dimension of the vector.
bool HasKey(const std::string &key)
Real Sum() const
Returns sum of the elements.
void Feedforward(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network (with 2 swapping buffers).
int NumArgs() const
Number of positional parameters (cf. argc-1).
A class representing a vector.
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
void Set(Real f)
Set all members of a vector to a specified value.
void SetDropoutRate(BaseFloat r)
Set the dropout rate.
std::string InfoGradient(bool header=true) const
Create string with per-component gradient statistics.
std::string InfoPropagate(bool header=true) const
Create string with propagation-buffer statistics.
void SetTrainOptions(const NnetTrainOptions &opts)
Set hyper-parameters of the training (pushes to all UpdatableComponents).
void Resize(const MatrixIndexT r, const MatrixIndexT c, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Sets matrix to a specified size (zero is OK as long as both r and c are zero).
std::string ReportPerClass()
Generate string with per-class error report.
void Register(OptionsItf *opts)
int main(int argc, char *argv[])
double Elapsed() const
Returns time in seconds.
void InitFromString(const std::string &s)
Initialize from string; the format for string 's' is: 'multitask,<type1>,<dim1>,<weight1>,...,<typeN>,<dimN>,<weightN>'.