nnet-combine-fast.cc
// nnet2bin/nnet-combine-fast.cc

// Copyright 2012  Johns Hopkins University (author: Daniel Povey)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "hmm/transition-model.h"
#include "nnet2/combine-nnet-fast.h"
#include "nnet2/am-nnet.h"

int main(int argc, char *argv[]) {
  try {
    using namespace kaldi;
    using namespace kaldi::nnet2;
    typedef kaldi::int32 int32;
    typedef kaldi::int64 int64;

    const char *usage =
        "Using a validation set, compute an optimal combination of a number of\n"
        "neural nets (the combination weights are separate for each layer and\n"
        "do not have to sum to one). The optimization is BFGS, which is initialized\n"
        "from the best of the individual input neural nets (or as specified by\n"
        "--initial-model)\n"
        "\n"
        "Usage: nnet-combine-fast [options] <model-in1> <model-in2> ... <model-inN> <valid-examples-in> <model-out>\n"
        "\n"
        "e.g.:\n"
        " nnet-combine-fast 1.1.nnet 1.2.nnet 1.3.nnet ark:valid.egs 2.nnet\n"
        "Caution: the first input neural net must not be a gradient.\n";

    bool binary_write = true;
    NnetCombineFastConfig combine_config;
    std::string use_gpu = "yes";

    ParseOptions po(usage);
    po.Register("binary", &binary_write, "Write output in binary mode");
    po.Register("use-gpu", &use_gpu,
                "yes|no|optional|wait, only has effect if compiled with CUDA");

    combine_config.Register(&po);

    po.Read(argc, argv);

    if (po.NumArgs() < 3) {
      po.PrintUsage();
      exit(1);
    }

    std::string
        nnet1_rxfilename = po.GetArg(1),
        valid_examples_rspecifier = po.GetArg(po.NumArgs() - 1),
        nnet_wxfilename = po.GetArg(po.NumArgs());

#if HAVE_CUDA==1
    // The GPU is only selected in the single-threaded case; if more than one
    // thread is configured, computation stays on the CPU.
    if (combine_config.num_threads == 1)
      CuDevice::Instantiate().SelectGpuId(use_gpu);
#endif

    TransitionModel trans_model;
    AmNnet am_nnet1;
    {
      bool binary_read;
      Input ki(nnet1_rxfilename, &binary_read);
      trans_model.Read(ki.Stream(), binary_read);
      am_nnet1.Read(ki.Stream(), binary_read);
    }

    int32 num_nnets = po.NumArgs() - 2;
    std::vector<Nnet> nnets(num_nnets);
    nnets[0] = am_nnet1.GetNnet();
    am_nnet1.GetNnet() = Nnet();  // Clear it to save memory.

    for (int32 n = 1; n < num_nnets; n++) {
      TransitionModel trans_model;
      AmNnet am_nnet;
      bool binary_read;
      Input ki(po.GetArg(1 + n), &binary_read);
      trans_model.Read(ki.Stream(), binary_read);
      am_nnet.Read(ki.Stream(), binary_read);
      nnets[n] = am_nnet.GetNnet();
    }

    std::vector<NnetExample> validation_set;  // stores validation frames.

    { // This block adds samples to "validation_set".
      SequentialNnetExampleReader example_reader(
          valid_examples_rspecifier);
      for (; !example_reader.Done(); example_reader.Next())
        validation_set.push_back(example_reader.Value());
      KALDI_LOG << "Read " << validation_set.size() << " examples from the "
                << "validation set.";
      KALDI_ASSERT(validation_set.size() > 0);
    }

    CombineNnetsFast(combine_config,
                     validation_set,
                     nnets,
                     &(am_nnet1.GetNnet()));

    {
      Output ko(nnet_wxfilename, binary_write);
      trans_model.Write(ko.Stream(), binary_write);
      am_nnet1.Write(ko.Stream(), binary_write);
    }

    KALDI_LOG << "Finished combining neural nets, wrote model to "
              << nnet_wxfilename;
    return (validation_set.size() == 0 ? 1 : 0);
  } catch(const std::exception &e) {
    std::cerr << e.what() << '\n';
    return -1;
  }
}
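As the usage string above says, the tool learns one combination weight per input net per layer (the weights need not sum to one), initializes them from the best individual input net (or from --initial-model), and optimizes them with BFGS against the objective on the validation examples; the heavy lifting happens inside CombineNnetsFast. Roughly speaking, once a set of weights is chosen, each updatable layer of the output net is a weighted sum of the corresponding layers of the input nets. The sketch below is only a minimal, self-contained illustration of applying such weights to plain parameter vectors; it is not the Kaldi API, and the function name CombineParams and the nested-vector layout are made up for this example.

// Minimal sketch (not Kaldi code): apply per-layer combination weights to
// parameters from several nets.  params[m][c] is the parameter vector of
// component (layer) c of net m; weights[m][c] is its combination weight.
#include <cassert>
#include <cstdio>
#include <vector>

// Returns combined parameters: out[c] = sum_m weights[m][c] * params[m][c].
std::vector<std::vector<double> > CombineParams(
    const std::vector<std::vector<std::vector<double> > > &params,
    const std::vector<std::vector<double> > &weights) {
  size_t num_nets = params.size(), num_components = params[0].size();
  std::vector<std::vector<double> > combined(num_components);
  for (size_t c = 0; c < num_components; c++) {
    combined[c].assign(params[0][c].size(), 0.0);
    for (size_t m = 0; m < num_nets; m++) {
      assert(params[m][c].size() == combined[c].size());
      for (size_t i = 0; i < combined[c].size(); i++)
        combined[c][i] += weights[m][c] * params[m][c][i];
    }
  }
  return combined;
}

int main() {
  // Two "nets", each with a single 3-dimensional component.
  std::vector<std::vector<std::vector<double> > > params = {
      {{1.0, 2.0, 3.0}}, {{3.0, 2.0, 1.0}}};
  // One weight per (net, component); note they need not sum to one.
  std::vector<std::vector<double> > weights = {{0.5}, {0.5}};
  std::vector<std::vector<double> > out = CombineParams(params, weights);
  for (size_t i = 0; i < out[0].size(); i++)
    std::printf("%g ", out[0][i]);  // prints: 2 2 2
  std::printf("\n");
  return 0;
}

The sketch only shows how one fixed set of weights would be applied; in the actual binary the weights themselves are the variables that BFGS searches over, scored on the validation set read from <valid-examples-in>.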