nnet-combine.cc
Go to the documentation of this file.
1 // nnet2bin/nnet-combine.cc
2 
3 // Copyright 2012 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include "base/kaldi-common.h"
21 #include "util/common-utils.h"
22 #include "hmm/transition-model.h"
23 #include "nnet2/combine-nnet.h"
24 #include "nnet2/am-nnet.h"
25 
26 
27 int main(int argc, char *argv[]) {
28  try {
29  using namespace kaldi;
30  using namespace kaldi::nnet2;
31  typedef kaldi::int32 int32;
32  typedef kaldi::int64 int64;
33 
34  const char *usage =
35  "Using a validation set, compute an optimal combination of a number of\n"
36  "neural nets (the combination weights are separate for each layer and\n"
37  "do not have to sum to one). The optimization is BFGS, which is initialized\n"
38  "from the best of the individual input neural nets (or as specified by\n"
39  "--initial-model)\n"
40  "\n"
41  "Usage: nnet-combine [options] <model-in1> <model-in2> ... <model-inN> <valid-examples-in> <model-out>\n"
42  "\n"
43  "e.g.:\n"
44  " nnet-combine 1.1.nnet 1.2.nnet 1.3.nnet ark:valid.egs 2.nnet\n"
45  "Caution: the first input neural net must not be a gradient.\n";
46 
47  bool binary_write = true;
48  NnetCombineConfig combine_config;
49 
50  ParseOptions po(usage);
51  po.Register("binary", &binary_write, "Write output in binary mode");
52 
53  combine_config.Register(&po);
54 
55  po.Read(argc, argv);
56 
57  if (po.NumArgs() < 3) {
58  po.PrintUsage();
59  exit(1);
60  }
61 
62  std::string
63  nnet1_rxfilename = po.GetArg(1),
64  valid_examples_rspecifier = po.GetArg(po.NumArgs() - 1),
65  nnet_wxfilename = po.GetArg(po.NumArgs());
66 
67  TransitionModel trans_model;
68  AmNnet am_nnet1;
69  {
70  bool binary_read;
71  Input ki(nnet1_rxfilename, &binary_read);
72  trans_model.Read(ki.Stream(), binary_read);
73  am_nnet1.Read(ki.Stream(), binary_read);
74  }
75 
76  int32 num_nnets = po.NumArgs() - 2;
77  std::vector<Nnet> nnets(num_nnets);
78  nnets[0] = am_nnet1.GetNnet();
79  am_nnet1.GetNnet() = Nnet(); // Clear it to save memory.
80 
81  for (int32 n = 1; n < num_nnets; n++) {
82  TransitionModel trans_model;
83  AmNnet am_nnet;
84  bool binary_read;
85  Input ki(po.GetArg(1 + n), &binary_read);
86  trans_model.Read(ki.Stream(), binary_read);
87  am_nnet.Read(ki.Stream(), binary_read);
88  nnets[n] = am_nnet.GetNnet();
89  }
90 
91  std::vector<NnetExample> validation_set; // stores validation
92  // frames.
93 
94  { // This block adds samples to "validation_set".
95  SequentialNnetExampleReader example_reader(
96  valid_examples_rspecifier);
97  for (; !example_reader.Done(); example_reader.Next())
98  validation_set.push_back(example_reader.Value());
99  KALDI_LOG << "Read " << validation_set.size() << " examples from the "
100  << "validation set.";
101  KALDI_ASSERT(validation_set.size() > 0);
102  }
103 
104  CombineNnets(combine_config,
105  validation_set,
106  nnets,
107  &(am_nnet1.GetNnet()));
108 
109  {
110  Output ko(nnet_wxfilename, binary_write);
111  trans_model.Write(ko.Stream(), binary_write);
112  am_nnet1.Write(ko.Stream(), binary_write);
113  }
114 
115  KALDI_LOG << "Finished combining neural nets, wrote model to "
116  << nnet_wxfilename;
117  return (validation_set.size() == 0 ? 1 : 0);
118  } catch(const std::exception &e) {
119  std::cerr << e.what() << '\n';
120  return -1;
121  }
122 }
123 
124 
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
int main(int argc, char *argv[])
Definition: nnet-combine.cc:27
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void Read(std::istream &is, bool binary)
Definition: am-nnet.cc:39
kaldi::int32 int32
void Register(OptionsItf *opts)
Definition: combine-nnet.h:50
void Register(const std::string &name, bool *ptr, const std::string &doc)
std::istream & Stream()
Definition: kaldi-io.cc:826
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
std::ostream & Stream()
Definition: kaldi-io.cc:701
void Read(std::istream &is, bool binary)
void Write(std::ostream &os, bool binary) const
Definition: am-nnet.cc:31
struct rnnlm::@11::@12 n
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
Configuration class that controls neural net combination, where we combine a number of neural nets...
Definition: combine-nnet.h:35
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void Write(std::ostream &os, bool binary) const
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
static void CombineNnets(const Vector< BaseFloat > &scale_params, const std::vector< Nnet > &nnets, Nnet *dest)
Definition: combine-nnet.cc:28
#define KALDI_LOG
Definition: kaldi-error.h:153
const Nnet & GetNnet() const
Definition: am-nnet.h:61