30   using namespace kaldi;
    36       "Perform one iteration (epoch) of Neural Network training with\n"    37       "mini-batch Stochastic Gradient Descent. The training targets\n"    38       "are usually pdf-posteriors, prepared by ali-to-post.\n"    39       "Usage:  nnet-train-frmshuff [options] <feature-rspecifier> <targets-rspecifier> <model-in> [<model-out>]\n"    40       "e.g.: nnet-train-frmshuff scp:feats.scp ark:posterior.ark nnet.init nnet.iter1\n";
    52     po.Register(
"binary", &binary, 
"Write output in binary mode");
    54     bool crossvalidate = 
false;
    55     po.Register(
"cross-validate", &crossvalidate,
    56         "Perform cross-validation (don't back-propagate)");
    58     bool randomize = 
true;
    59     po.Register(
"randomize", &randomize,
    60         "Perform the frame-level shuffling within the Cache::");
    62     std::string feature_transform;
    63     po.Register(
"feature-transform", &feature_transform,
    64         "Feature transform in Nnet format");
    66     std::string objective_function = 
"xent";
    67     po.Register(
"objective-function", &objective_function,
    68         "Objective function : xent|mse|multitask");
    70     int32 max_frames = 360000;
    71     po.Register(
"max-frames", &max_frames,
    72         "Maximum number of frames an utterance can have (skipped if longer)");
    74     int32 length_tolerance = 5;
    75     po.Register(
"length-tolerance", &length_tolerance,
    76         "Allowed length mismatch of features/targets/weights "    77         "(in frames, we truncate to the shortest)");
    79     std::string frame_weights;
    80     po.Register(
"frame-weights", &frame_weights,
    81         "Per-frame weights, used to re-scale gradients.");
    83     std::string utt_weights;
    84     po.Register(
"utt-weights", &utt_weights,
    85         "Per-utterance weights, used to re-scale frame-weights.");
    87     std::string use_gpu=
"yes";
    88     po.Register(
"use-gpu", &use_gpu,
    89         "yes|no|optional, only has effect if compiled with CUDA");
    93     if (po.NumArgs() != 3 + (crossvalidate ? 0 : 1)) {
    98     std::string feature_rspecifier = po.GetArg(1),
    99       targets_rspecifier = po.GetArg(2),
   100       model_filename = po.GetArg(3);
   102     std::string target_model_filename;
   103     if (!crossvalidate) {
   104       target_model_filename = po.GetArg(4);
   107     using namespace kaldi;
   112     CuDevice::Instantiate().SelectGpuId(use_gpu);
   116     if (feature_transform != 
"") {
   117       nnet_transf.
Read(feature_transform);
   121     nnet.
Read(model_filename);
   129     kaldi::int64 total_frames = 0;
   134     if (frame_weights != 
"") {
   135       weights_reader.
Open(frame_weights);
   138     if (utt_weights != 
"") {
   139       utt_weights_reader.
Open(utt_weights);
   147     Xent xent(loss_opts);
   151     if (0 == objective_function.compare(0, 9, 
"multitask")) {
   157       multitask.InitFromString(objective_function);
   163     KALDI_LOG << (crossvalidate ? 
"CROSS-VALIDATION" : 
"TRAINING")
   170     double time_io_accu = 0.0;
   173     while (!feature_reader.Done()) {
   176       CuDevice::Instantiate().CheckGpuHealth();
   180       for ( ; !feature_reader.Done(); feature_reader.Next()) {
   181         if (feature_randomizer.IsFull()) {
   186         std::string utt = feature_reader.Key();
   189         if (!targets_reader.HasKey(utt)) {
   195         if (frame_weights != 
"" && !weights_reader.
HasKey(utt)) {
   196           KALDI_WARN << utt << 
", missing per-frame weights";
   201         if (utt_weights != 
"" && !utt_weights_reader.
HasKey(utt)) {
   202           KALDI_WARN << utt << 
", missing per-utterance weight";
   208         Posterior targets = targets_reader.Value(utt);
   211         if (frame_weights != 
"") {
   212           weights = weights_reader.
Value(utt);
   218         if (utt_weights != 
"") {
   221           if (w == 0.0) 
continue;  
   226         time_io_accu += time_io.
Elapsed();
   230         if (mat.
NumRows() > max_frames) {
   231           KALDI_WARN << 
"Utterance too long, skipping! " << utt
   232             << 
" (length " << mat.
NumRows() << 
", max_frames "   233             << max_frames << 
")";
   241           std::vector<int32> length;
   242           length.push_back(mat.
NumRows());
   243           length.push_back(targets.size());
   244           length.push_back(weights.
Dim());
   246           int32 min = *std::min_element(length.begin(), length.end());
   247           int32 max = *std::max_element(length.begin(), length.end());
   249           if (max - min < length_tolerance) {
   252             if (targets.size() != min) targets.resize(min);
   255             KALDI_WARN << 
"Length mismatch! Targets " << targets.size()
   256                        << 
", features " << mat.
NumRows() << 
", " << utt;
   269           if (weight_min == 0.0) {
   271             std::vector<MatrixIndexT> keep_frames;
   272             for (int32 
i = 0; 
i < weights.
Dim(); 
i++) {
   273               if (weights(
i) > 0.0) {
   274                 keep_frames.push_back(
i);
   279             if (keep_frames.size() == 0) 
continue;
   284             tmp_feats.Swap(&feats_transf);
   288             for (int32 
i = 0; 
i < keep_frames.size(); 
i++) {
   289               tmp_targets.push_back(targets[keep_frames[
i]]);
   291             tmp_targets.swap(targets);
   295             for (int32 
i = 0; 
i < keep_frames.size(); 
i++) {
   296               tmp_weights(
i) = weights(keep_frames[
i]);
   298             tmp_weights.Swap(&weights);
   304         feature_randomizer.AddData(feats_transf);
   305         targets_randomizer.AddData(targets);
   306         weights_randomizer.AddData(weights);
   313       if (!crossvalidate && randomize) {
   314         const std::vector<int32>& mask =
   315           randomizer_mask.Generate(feature_randomizer.NumFrames());
   316         feature_randomizer.Randomize(mask);
   317         targets_randomizer.Randomize(mask);
   318         weights_randomizer.Randomize(mask);
   322       for ( ; !feature_randomizer.Done(); feature_randomizer.Next(),
   323                                           targets_randomizer.Next(),
   324                                           weights_randomizer.Next()) {
   327         const Posterior& nnet_tgt = targets_randomizer.Value();
   334         if (objective_function == 
"xent") {
   336           xent.Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
   337         } 
else if (objective_function == 
"mse") {
   339           mse.Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
   340         } 
else if (0 == objective_function.compare(0, 9, 
"multitask")) {
   342           multitask.Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
   344           KALDI_ERR << 
"Unknown objective function code : " << objective_function;
   347         if (!crossvalidate) {
   353         if (total_frames == 0) {
   354           KALDI_LOG << 
"### After " << total_frames << 
" frames,";
   356           if (!crossvalidate) {
   365           static int32 counter = 0;
   368           if (counter >= 25000) {
   369             KALDI_VLOG(2) << 
"### After " << total_frames << 
" frames,";
   371             if (!crossvalidate) {
   379         total_frames += nnet_in.
NumRows();
   384     KALDI_LOG << 
"### After " << total_frames << 
" frames,";
   386     if (!crossvalidate) {
   391     if (!crossvalidate) {
   392       nnet.
Write(target_model_filename, binary);
   395     KALDI_LOG << 
"Done " << num_done << 
" files, "   396       << num_no_tgt_mat << 
" with no tgt_mats, "   397       << num_other_error << 
" with other errors. "   398       << 
"[" << (crossvalidate ? 
"CROSS-VALIDATION" : 
"TRAINING")
   399       << 
", " << (randomize ? 
"RANDOMIZED" : 
"NOT-RANDOMIZED")
   400       << 
", " << time.
Elapsed() / 60 << 
" min, processing "   401       << total_frames / time.
Elapsed() << 
" frames per sec;"   402       << 
" i/o time " << 100.*time_io_accu/time.
Elapsed() << 
"%]";
   404     if (objective_function == 
"xent") {
   407     } 
else if (objective_function == 
"mse") {
   409     } 
else if (0 == objective_function.compare(0, 9, 
"multitask")) {
   412       KALDI_ERR << 
"Unknown objective function code : " << objective_function;
   416     CuDevice::Instantiate().PrintProfile();
   420   } 
catch(
const std::exception &e) {
   421     std::cerr << e.what();
 void Backpropagate(const CuMatrixBase< BaseFloat > &out_diff, CuMatrix< BaseFloat > *in_diff)
Perform backward pass through the network.
 
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
 
void Propagate(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network.
 
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix). 
 
int32 GetVerboseLevel()
Get verbosity level, usually set via the command-line '--verbose=' switch.
 
void Write(const std::string &wxfilename, bool binary) const
Write Nnet to 'wxfilename'.
 
Generates a randomly ordered vector of indices.
 
bool Open(const std::string &rspecifier)
 
Randomizes elements of a vector according to a mask. 
 
This class represents a matrix that's stored on the GPU if we have one, and in memory if not.
 
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero). 
 
Real Min() const
Returns the minimum value of any element, or +infinity for the empty vector. 
 
Allows random access to a collection of objects in an archive or script file; see The Table concept.
 
Configuration variables that affect how frame-level shuffling is done. 
 
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an utterance.
 
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more details.
 
const T & Value(const std::string &key)
 
std::string InfoBackPropagate(bool header=true) const
Create a string with back-propagation-buffer statistics.
 
void CopyRows(const CuMatrixBase< Real > &src, const CuArrayBase< MatrixIndexT > &indexes)
Copies row r from row indexes[r] of src. 
 
A templated class for reading objects sequentially from an archive or script file; see The Table concept.
 
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename'.
 
void Register(OptionsItf *opts)
 
MatrixIndexT Dim() const
Returns the dimension of the vector. 
 
void Scale(Real alpha)
Multiplies all elements by this constant. 
 
bool HasKey(const std::string &key)
 
void Feedforward(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network (with 2 swapping buffers).
 
Shuffles rows of a matrix according to the indices in the mask.
 
Matrix for CUDA computing. 
 
MatrixIndexT NumCols() const
 
A class representing a vector. 
 
Class CuArray represents a vector of an integer or struct of type T. 
 
#define KALDI_ASSERT(cond)
 
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix). 
 
void Set(Real f)
Set all members of a vector to a specified value. 
 
void Register(OptionsItf *opts)
 
void SetDropoutRate(BaseFloat r)
Set the dropout rate. 
 
std::string InfoGradient(bool header=true) const
Create a string with per-component gradient statistics.
 
std::string InfoPropagate(bool header=true) const
Create a string with propagation-buffer statistics.
 
void SetTrainOptions(const NnetTrainOptions &opts)
Set hyper-parameters of the training (pushes to all UpdatableComponents).
 
void Resize(const MatrixIndexT r, const MatrixIndexT c, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Sets matrix to a specified size (zero is OK as long as both r and c are zero). 
 
Randomizes elements of a vector according to a mask. 
 
void Register(OptionsItf *opts)
 
MatrixIndexT NumRows() const
Dimensions. 
 
double Elapsed() const
Returns time in seconds.