29 int main(
int argc, 
char *argv[]) {
    30   using namespace kaldi;
    36       "Perform one iteration of NN training by SGD with per-utterance updates.\n"    37       "The training targets are represented as pdf-posteriors, usually prepared "    39       "Usage: nnet-train-perutt [options] "    40       "<feature-rspecifier> <targets-rspecifier> <model-in> [<model-out>]\n"    41       "e.g.: nnet-train-perutt scp:feature.scp ark:posterior.ark nnet.init nnet.iter1\n";
    51     po.
Register(
"binary", &binary, 
"Write output in binary mode");
    53     bool crossvalidate = 
false;
    54     po.
Register(
"cross-validate", &crossvalidate,
    55         "Perform cross-validation (don't backpropagate)");
    57     std::string feature_transform;
    58     po.
Register(
"feature-transform", &feature_transform,
    59         "Feature transform in Nnet format");
    61     std::string objective_function = 
"xent";
    62     po.
Register(
"objective-function", &objective_function,
    63         "Objective function : xent|mse");
    65     int32 length_tolerance = 5;
    66     po.
Register(
"length-tolerance", &length_tolerance,
    67         "Allowed length difference of features/targets (frames)");
    69     std::string frame_weights;
    70     po.
Register(
"frame-weights", &frame_weights,
    71         "Per-frame weights to scale gradients (frame selection/weighting).");
    74     po.
Register(
"max-frames",&max_frames, 
"Maximum number of frames a segment can have to be processed");
    76     std::string use_gpu=
"yes";
    78         "yes|no|optional, only has effect if compiled with CUDA");
    81     bool randomize = 
false;
    83         "Dummy, for compatibility with 'steps/nnet/train_scheduler.sh'");
    88     if (po.
NumArgs() != 3 + (crossvalidate ? 0 : 1)) {
    93     std::string feature_rspecifier = po.
GetArg(1),
    94       targets_rspecifier = po.
GetArg(2),
    95       model_filename = po.
GetArg(3);
    97     std::string target_model_filename;
    99       target_model_filename = po.
GetArg(4);
   102     using namespace kaldi;
   107     CuDevice::Instantiate().SelectGpuId(use_gpu);
   111     if (feature_transform != 
"") {
   112       nnet_transf.
Read(feature_transform);
   116     nnet.
Read(model_filename);
   124     kaldi::int64 total_frames = 0;
   129     if (frame_weights != 
"") {
   130       weights_reader.
Open(frame_weights);
   133     Xent xent(loss_opts);
   137     if (0 == objective_function.compare(0, 9, 
"multitask")) {
   149     KALDI_LOG << (crossvalidate?
"CROSS-VALIDATION":
"TRAINING") << 
" STARTED";
   156     for ( ; !feature_reader.
Done(); feature_reader.
Next()) {
   157       std::string utt = feature_reader.
Key();
   160       if (!targets_reader.
HasKey(utt)) {
   166       if (frame_weights != 
"" && !weights_reader.
HasKey(utt)) {
   167         KALDI_WARN << utt << 
", missing per-frame weights";
   169         feature_reader.
Next();
   176       if (mat.
NumRows() > max_frames) {
   178           << 
" that has " << mat.
NumRows() << 
" frames,"   179           << 
" it is longer than '--max-frames'" << max_frames;
   185       if (frame_weights != 
"") {
   186         frm_weights = weights_reader.
Value(utt);
   189         frm_weights.
Set(1.0);
   194         std::vector<int32> length;
   195         length.push_back(mat.
NumRows());
   196         length.push_back(nnet_tgt.size());
   197         length.push_back(frm_weights.
Dim());
   199         int32 min = *std::min_element(length.begin(), length.end());
   200         int32 max = *std::max_element(length.begin(), length.end());
   202         if (max - min < length_tolerance) {
   204           if (nnet_tgt.size() != min) nnet_tgt.resize(min);
   207           KALDI_WARN << utt << 
", length mismatch of targets " << nnet_tgt.size()
   208                      << 
" and features " << mat.
NumRows();
   220       if (objective_function == 
"xent") {
   222         xent.
Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
   223       } 
else if (objective_function == 
"mse") {
   225         mse.
Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
   226       } 
else if (0 == objective_function.compare(0, 9, 
"multitask")) {
   228         multitask.
Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
   230         KALDI_ERR << 
"Unknown objective function code : "   231                   << objective_function;
   234       if (!crossvalidate) {
   240       if (total_frames == 0) {
   241         KALDI_LOG << 
"### After " << total_frames << 
" frames,";
   243         if (!crossvalidate) {
   252         static int32 counter = 0;
   255         if (counter >= 25000) {
   256           KALDI_VLOG(2) << 
"### After " << total_frames << 
" frames,";
   258           if (!crossvalidate) {
   267       total_frames += frm_weights.
Sum();
   271     KALDI_LOG << 
"### After " << total_frames << 
" frames,";
   273     if (!crossvalidate) {
   278     if (!crossvalidate) {
   279       nnet.
Write(target_model_filename, binary);
   282     KALDI_LOG << 
"Done " << num_done << 
" files, "   283       << num_no_tgt_mat << 
" with no tgt_mats, "   284       << num_other_error << 
" with other errors. "   285       << 
"[" << (crossvalidate ? 
"CROSS-VALIDATION" : 
"TRAINING")
   286       << 
", " << (randomize ? 
"RANDOMIZED" : 
"NOT-RANDOMIZED")
   287       << 
", " << time.
Elapsed() / 60 << 
" min, processing "   288       << total_frames / time.
Elapsed() << 
" frames per sec.]";
   290     if (objective_function == 
"xent") {
   293     } 
else if (objective_function == 
"mse") {
   295     } 
else if (0 == objective_function.compare(0, 9, 
"multitask")) {
   298       KALDI_ERR << 
"Unknown objective function code : " << objective_function;
   302     CuDevice::Instantiate().PrintProfile();
   306   } 
catch(
const std::exception &e) {
   307     std::cerr << e.what();
 void Backpropagate(const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
Perform backward pass through the network.
 
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
 
void Propagate(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network.
 
std::string Report()
Generate string with error report. 
 
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix). 
 
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor]. 
 
int32 GetVerboseLevel()
Get verbosity level, usually set via command line '--verbose=' switch.
 
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename'.
 
bool Open(const std::string &rspecifier)
 
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
 
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero). 
 
void Register(const std::string &name, bool *ptr, const std::string &doc)
 
std::string Report()
Generate string with error report. 
 
Allows random access to a collection of objects in an archive or script file; see The Table concept...
 
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an utterance.
 
void Eval(const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const CuMatrixBase< BaseFloat > &target, CuMatrix< BaseFloat > *diff)
Evaluate mean square error using target-matrix.
 
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
 
const T & Value(const std::string &key)
 
std::string InfoBackPropagate(bool header=true) const
Create string with back-propagation-buffer statistics.
 
void Eval(const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const CuMatrixBase< BaseFloat > &target, CuMatrix< BaseFloat > *diff)
Evaluate cross entropy using target-matrix (supports soft labels).
 
std::string Report()
Generate string with error report.
 
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
 
void Eval(const VectorBase< BaseFloat > &frame_weights, const CuMatrixBase< BaseFloat > &net_out, const CuMatrixBase< BaseFloat > &target, CuMatrix< BaseFloat > *diff)
Evaluate mean square error using target-matrix.
 
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables. 
 
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename'.
 
void Register(OptionsItf *opts)
 
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility. 
 
MatrixIndexT Dim() const
Returns the dimension of the vector. 
 
bool HasKey(const std::string &key)
 
Real Sum() const
Returns sum of the elements. 
 
void Feedforward(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network (with 2 swapping buffers).
 
int NumArgs() const
Number of positional parameters (c.f. argc-1). 
 
A class representing a vector. 
 
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix). 
 
void Set(Real f)
Set all members of a vector to a specified value. 
 
void SetDropoutRate(BaseFloat r)
Set the dropout rate. 
 
std::string InfoGradient(bool header=true) const
Create string with per-component gradient statistics.
 
std::string InfoPropagate(bool header=true) const
Create string with propagation-buffer statistics.
 
void SetTrainOptions(const NnetTrainOptions &opts)
Set hyper-parameters of the training (pushes to all UpdatableComponents).
 
void Resize(const MatrixIndexT r, const MatrixIndexT c, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Sets matrix to a specified size (zero is OK as long as both r and c are zero). 
 
std::string ReportPerClass()
Generate string with per-class error report.
 
void Register(OptionsItf *opts)
 
int main(int argc, char *argv[])
 
double Elapsed() const
Returns time in seconds. 
 
void InitFromString(const std::string &s)
Initialize from string, the format for string 's' is : 'multitask,<type1>,<dim1>,<weight1>,...,<typeN>,<dimN>,<weightN>'.