All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
nnet-forward.cc File Reference
#include <limits>
#include "nnet/nnet-nnet.h"
#include "nnet/nnet-loss.h"
#include "nnet/nnet-pdf-prior.h"
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "base/timer.h"
Include dependency graph for nnet-forward.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

int main ( int  argc,
char *  argv[] 
)

Definition at line 30 of file nnet-forward.cc.

References CuMatrixBase< Real >::Add(), CuMatrixBase< Real >::ApplyLog(), PdfPriorOptions::class_frame_counts, SequentialTableReader< Holder >::Done(), Timer::Elapsed(), Nnet::Feedforward(), ParseOptions::GetArg(), Nnet::GetLastComponent(), Component::GetType(), kaldi::GetVerboseLevel(), KALDI_ERR, KALDI_ISFINITE, KALDI_LOG, KALDI_VLOG, KALDI_WARN, Component::kBlockSoftmax, SequentialTableReader< Holder >::Key(), Component::kSoftmax, CuMatrixBase< Real >::Max(), CuMatrixBase< Real >::Min(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), MatrixBase< Real >::NumRows(), ParseOptions::PrintUsage(), ParseOptions::Read(), Nnet::Read(), PdfPriorOptions::Register(), ParseOptions::Register(), Nnet::RemoveLastComponent(), Nnet::SetDropoutRate(), PdfPrior::SubtractOnLogpost(), MatrixBase< Real >::Sum(), CuMatrixBase< Real >::Sum(), Component::TypeToMarker(), SequentialTableReader< Holder >::Value(), and TableWriter< Holder >::Write().

30  {
31  using namespace kaldi;
32  using namespace kaldi::nnet1;
33  try {
34  const char *usage =
35  "Perform forward pass through Neural Network.\n"
36  "Usage: nnet-forward [options] <nnet1-in> <feature-rspecifier> <feature-wspecifier>\n"
37  "e.g.: nnet-forward final.nnet ark:input.ark ark:output.ark\n";
38 
39  ParseOptions po(usage);
40 
41  PdfPriorOptions prior_opts;
42  prior_opts.Register(&po);
43 
44  std::string feature_transform;
45  po.Register("feature-transform", &feature_transform,
46  "Feature transform in front of main network (in nnet format)");
47 
48  bool no_softmax = false;
49  po.Register("no-softmax", &no_softmax,
50  "Removes the last component with Softmax, if found. The pre-softmax "
51  "activations are the output of the network. Decoding them leads to "
52  "the same lattices as if we had used 'log-posteriors'.");
53 
54  bool apply_log = false;
55  po.Register("apply-log", &apply_log, "Transform NN output by log()");
56 
57  std::string use_gpu="no";
58  po.Register("use-gpu", &use_gpu,
59  "yes|no|optional, only has effect if compiled with CUDA");
60 
61  using namespace kaldi;
62  using namespace kaldi::nnet1;
63  typedef kaldi::int32 int32;
64 
65  po.Read(argc, argv);
66 
67  if (po.NumArgs() != 3) {
68  po.PrintUsage();
69  exit(1);
70  }
71 
72  std::string model_filename = po.GetArg(1),
73  feature_rspecifier = po.GetArg(2),
74  feature_wspecifier = po.GetArg(3);
75 
76  // Select the GPU
77 #if HAVE_CUDA == 1
78  CuDevice::Instantiate().SelectGpuId(use_gpu);
79 #endif
80 
81  Nnet nnet_transf;
82  if (feature_transform != "") {
83  nnet_transf.Read(feature_transform);
84  }
85 
86  Nnet nnet;
87  nnet.Read(model_filename);
88  // optionally remove softmax,
89  Component::ComponentType last_comp_type = nnet.GetLastComponent().GetType();
90  if (no_softmax) {
91  if (last_comp_type == Component::kSoftmax ||
92  last_comp_type == Component::kBlockSoftmax) {
93  KALDI_LOG << "Removing " << Component::TypeToMarker(last_comp_type)
94  << " from the nnet " << model_filename;
95  nnet.RemoveLastComponent();
96  } else {
97  KALDI_WARN << "Last component 'NOT-REMOVED' by --no-softmax=true, "
98  << "the component was " << Component::TypeToMarker(last_comp_type);
99  }
100  }
101 
102  // avoid some bad option combinations,
103  if (apply_log && no_softmax) {
104  KALDI_ERR << "Cannot use both --apply-log=true --no-softmax=true, "
105  << "use only one of the two!";
106  }
107 
108  // we will subtract log-priors later,
109  PdfPrior pdf_prior(prior_opts);
110 
111  // disable dropout,
112  nnet_transf.SetDropoutRate(0.0);
113  nnet.SetDropoutRate(0.0);
114 
115  kaldi::int64 tot_t = 0;
116 
117  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
118  BaseFloatMatrixWriter feature_writer(feature_wspecifier);
119 
120  CuMatrix<BaseFloat> feats, feats_transf, nnet_out;
121  Matrix<BaseFloat> nnet_out_host;
122 
123  Timer time;
124  double time_now = 0;
125  int32 num_done = 0;
126 
127  // main loop,
128  for (; !feature_reader.Done(); feature_reader.Next()) {
129  // read
130  Matrix<BaseFloat> mat = feature_reader.Value();
131  std::string utt = feature_reader.Key();
132  KALDI_VLOG(2) << "Processing utterance " << num_done+1
133  << ", " << utt
134  << ", " << mat.NumRows() << "frm";
135 
136 
137  if (!KALDI_ISFINITE(mat.Sum())) { // check there's no nan/inf,
138  KALDI_ERR << "NaN or inf found in features for " << utt;
139  }
140 
141  // push it to gpu,
142  feats = mat;
143 
144  // fwd-pass, feature transform,
145  nnet_transf.Feedforward(feats, &feats_transf);
146  if (!KALDI_ISFINITE(feats_transf.Sum())) { // check there's no nan/inf,
147  KALDI_ERR << "NaN or inf found in transformed-features for " << utt;
148  }
149 
150  // fwd-pass, nnet,
151  nnet.Feedforward(feats_transf, &nnet_out);
152  if (!KALDI_ISFINITE(nnet_out.Sum())) { // check there's no nan/inf,
153  KALDI_ERR << "NaN or inf found in nn-output for " << utt;
154  }
155 
156  // convert posteriors to log-posteriors,
157  if (apply_log) {
158  if (!(nnet_out.Min() >= 0.0 && nnet_out.Max() <= 1.0)) {
159  KALDI_WARN << "Applying 'log()' to data which don't seem to be "
160  << "probabilities," << utt;
161  }
162  nnet_out.Add(1e-20); // avoid log(0),
163  nnet_out.ApplyLog();
164  }
165 
166  // subtract log-priors from log-posteriors or pre-softmax,
167  if (prior_opts.class_frame_counts != "") {
168  pdf_prior.SubtractOnLogpost(&nnet_out);
169  }
170 
171  // download from GPU,
172  nnet_out_host = Matrix<BaseFloat>(nnet_out);
173 
174  // write,
175  if (!KALDI_ISFINITE(nnet_out_host.Sum())) { // check there's no nan/inf,
176  KALDI_ERR << "NaN or inf found in final output nn-output for " << utt;
177  }
178  feature_writer.Write(feature_reader.Key(), nnet_out_host);
179 
180  // progress log,
181  if (num_done % 100 == 0) {
182  time_now = time.Elapsed();
183  KALDI_VLOG(1) << "After " << num_done << " utterances: time elapsed = "
184  << time_now/60 << " min; processed " << tot_t/time_now
185  << " frames per second.";
186  }
187  num_done++;
188  tot_t += mat.NumRows();
189  }
190 
191  // final message,
192  KALDI_LOG << "Done " << num_done << " files"
193  << " in " << time.Elapsed()/60 << "min,"
194  << " (fps " << tot_t/time.Elapsed() << ")";
195 
196 #if HAVE_CUDA == 1
197  if (GetVerboseLevel() >= 1) {
198  CuDevice::Instantiate().PrintProfile();
199  }
200 #endif
201 
202  if (num_done == 0) return -1;
203  return 0;
204  } catch(const std::exception &e) {
205  std::cerr << e.what();
206  return -1;
207  }
208 }
void RemoveLastComponent()
Remove the last of the Components.
Definition: nnet-nnet.cc:206
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
Real Sum() const
Definition: cu-matrix.cc:2658
Real Sum() const
Returns sum of all elements in matrix.
int32 GetVerboseLevel()
Definition: kaldi-error.h:69
#define KALDI_ISFINITE(x)
Definition: kaldi-math.h:74
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:366
Real Min() const
Definition: cu-matrix.cc:2700
const Component & GetLastComponent() const
LastComponent accessor.
Definition: nnet-nnet.cc:161
double Elapsed()
Returns time in seconds.
Definition: timer.h:65
ComponentType
Component type identification mechanism.
static const char * TypeToMarker(ComponentType t)
Converts component type to marker.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Max(const CuMatrixBase< Real > &A)
Do, elementwise, *this = max(*this, A).
Definition: cu-matrix.cc:700
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
#define KALDI_ERR
Definition: kaldi-error.h:127
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename'.
Definition: nnet-nnet.cc:333
#define KALDI_WARN
Definition: kaldi-error.h:130
void Feedforward(const CuMatrixBase< BaseFloat > &in, CuMatrix< BaseFloat > *out)
Perform forward pass through the network (with 2 swapping buffers).
Definition: nnet-nnet.cc:131
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:58
void SetDropoutRate(BaseFloat r)
Set the dropout rate.
Definition: nnet-nnet.cc:268
#define KALDI_VLOG(v)
Definition: kaldi-error.h:136
virtual ComponentType GetType() const =0
Get Type Identification of the component.
void Add(Real value)
Definition: cu-matrix.cc:546
#define KALDI_LOG
Definition: kaldi-error.h:133
void Register(OptionsItf *opts)