nnet3-compute-from-egs.cc
Go to the documentation of this file.
1 // nnet3bin/nnet3-compute-from-egs.cc
2 
3 // Copyright 2015 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABILITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include "base/kaldi-common.h"
21 #include "util/common-utils.h"
22 #include "hmm/transition-model.h"
23 #include "nnet3/nnet-nnet.h"
25 #include "nnet3/nnet-optimize.h"
26 #include "transform/lda-estimate.h"
27 
28 
29 namespace kaldi {
30 namespace nnet3 {
31 
33  public:
34  NnetComputerFromEg(const Nnet &nnet):
35  nnet_(nnet), compiler_(nnet) { }
36 
37  // Compute the output (which will have the same number of rows as the number
38  // of Indexes in the output with the name 'output_name' of the eg),
39  // and put it in "*output".
40  // An output with the name 'output_name' is expected to exist in the network.
41  void Compute(const NnetExample &eg, const std::string &output_name,
42  Matrix<BaseFloat> *output) {
43  ComputationRequest request;
44  bool need_backprop = false, store_stats = false;
45  GetComputationRequest(nnet_, eg, need_backprop, store_stats, &request);
46  const NnetComputation &computation = *(compiler_.Compile(request));
47  NnetComputeOptions options;
48  if (GetVerboseLevel() >= 3)
49  options.debug = true;
50  NnetComputer computer(options, computation, nnet_, NULL);
51  computer.AcceptInputs(nnet_, eg.io);
52  computer.Run();
53  const CuMatrixBase<BaseFloat> &nnet_output = computer.GetOutput(output_name);
54  output->Resize(nnet_output.NumRows(), nnet_output.NumCols());
55  nnet_output.CopyToMat(output);
56  }
57  private:
58  const Nnet &nnet_;
60 
61 };
62 
63 }
64 }
65 
66 int main(int argc, char *argv[]) {
67  try {
68  using namespace kaldi;
69  using namespace kaldi::nnet3;
70  typedef kaldi::int32 int32;
71  typedef kaldi::int64 int64;
72 
73  const char *usage =
74  "Read input nnet training examples, and compute the output for each one.\n"
75  "If --apply-exp=true, apply the Exp() function to the output before writing\n"
76  "it out.\n"
77  "\n"
78  "Usage: nnet3-compute-from-egs [options] <raw-nnet-in> <training-examples-in> <matrices-out>\n"
79  "e.g.:\n"
80  "nnet3-compute-from-egs --apply-exp=true 0.raw ark:1.egs ark:- | matrix-sum-rows ark:- ... \n"
81  "See also: nnet3-compute\n";
82 
83  bool binary_write = true,
84  apply_exp = false;
85  std::string use_gpu = "yes";
86  std::string output_name = "output";
87 
88  ParseOptions po(usage);
89  po.Register("binary", &binary_write, "Write output in binary mode");
90  po.Register("apply-exp", &apply_exp, "If true, apply exp function to "
91  "output");
92  po.Register("output-name", &output_name, "Do computation for "
93  "specified output-node");
94  po.Register("use-gpu", &use_gpu,
95  "yes|no|optional|wait, only has effect if compiled with CUDA");
96 
97  po.Read(argc, argv);
98 
99  if (po.NumArgs() != 3) {
100  po.PrintUsage();
101  exit(1);
102  }
103 
104 #if HAVE_CUDA==1
105  CuDevice::Instantiate().SelectGpuId(use_gpu);
106 #endif
107 
108  std::string nnet_rxfilename = po.GetArg(1),
109  examples_rspecifier = po.GetArg(2),
110  matrix_wspecifier = po.GetArg(3);
111 
112  Nnet nnet;
113  ReadKaldiObject(nnet_rxfilename, &nnet);
114 
115  NnetComputerFromEg computer(nnet);
116 
117  int64 num_egs = 0;
118 
119  SequentialNnetExampleReader example_reader(examples_rspecifier);
120  BaseFloatMatrixWriter matrix_writer(matrix_wspecifier);
121 
122  for (; !example_reader.Done(); example_reader.Next(), num_egs++) {
123  Matrix<BaseFloat> output;
124  computer.Compute(example_reader.Value(), output_name, &output);
125  KALDI_ASSERT(output.NumRows() != 0);
126  if (apply_exp)
127  output.ApplyExp();
128  matrix_writer.Write(example_reader.Key(), output);
129  }
130 #if HAVE_CUDA==1
131  CuDevice::Instantiate().PrintProfile();
132 #endif
133  KALDI_LOG << "Processed " << num_egs << " examples.";
134  return 0;
135  } catch(const std::exception &e) {
136  std::cerr << e.what() << '\n';
137  return -1;
138  }
139 }
NnetExample is the input data and corresponding label (or labels) for one or more frames of input...
Definition: nnet-example.h:111
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void CopyToMat(MatrixBase< OtherReal > *dst, MatrixTransposeType trans=kNoTrans) const
Definition: cu-matrix.cc:447
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
int32 GetVerboseLevel()
Get verbosity level, usually set via command line '--verbose=' switch.
Definition: kaldi-error.h:60
void Compute(const NnetExample &eg, const std::string &output_name, Matrix< BaseFloat > *output)
This class enables you to do the compilation and optimization in one call, and also ensures that if t...
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
int main(int argc, char *argv[])
void AcceptInputs(const Nnet &nnet, const std::vector< NnetIo > &io)
This convenience function calls AcceptInput() in turn on all the inputs in the training example...
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
std::shared_ptr< const NnetComputation > Compile(const ComputationRequest &request)
Does the compilation and returns a const pointer to the result, which is owned by this class...
int NumArgs() const
Number of positional parameters (c.f. argc-1).
Matrix for CUDA computing.
Definition: matrix-common.h:69
MatrixIndexT NumCols() const
Definition: cu-matrix.h:216
class NnetComputer is responsible for executing the computation described in the "computation" object...
Definition: nnet-compute.h:59
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
void Resize(const MatrixIndexT r, const MatrixIndexT c, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Sets matrix to a specified size (zero is OK as long as both r and c are zero).
MatrixIndexT NumRows() const
Dimensions.
Definition: cu-matrix.h:215
std::vector< NnetIo > io
"io" contains the input and output.
Definition: nnet-example.h:116
#define KALDI_LOG
Definition: kaldi-error.h:153
void GetComputationRequest(const Nnet &nnet, const NnetExample &eg, bool need_model_derivative, bool store_component_stats, ComputationRequest *request)
This function takes a NnetExample (which should already have been frame-selected, if desired...