nnet-combine.cc File Reference
Include dependency graph for nnet-combine.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main(int argc, char *argv[])

Definition at line 27 of file nnet-combine.cc.

References kaldi::nnet2::CombineNnets(), SequentialTableReader< Holder >::Done(), ParseOptions::GetArg(), AmNnet::GetNnet(), KALDI_ASSERT, KALDI_LOG, rnnlm::n, SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), ParseOptions::PrintUsage(), AmNnet::Read(), ParseOptions::Read(), TransitionModel::Read(), NnetCombineConfig::Register(), ParseOptions::Register(), Output::Stream(), Input::Stream(), SequentialTableReader< Holder >::Value(), AmNnet::Write(), and TransitionModel::Write().

// Body of main() for the nnet-combine tool: reads N input models and a set of
// validation examples, calls CombineNnets() on them, and writes one output model.
// Returns 0 on success, 1 if no validation examples were read, and -1 if a
// std::exception is caught.
27  {
28  try {
29  using namespace kaldi;
30  using namespace kaldi::nnet2;
31  typedef kaldi::int32 int32;
32  typedef kaldi::int64 int64;
33 
34  const char *usage =
35  "Using a validation set, compute an optimal combination of a number of\n"
36  "neural nets (the combination weights are separate for each layer and\n"
37  "do not have to sum to one). The optimization is BFGS, which is initialized\n"
38  "from the best of the individual input neural nets (or as specified by\n"
39  "--initial-model)\n"
40  "\n"
41  "Usage: nnet-combine [options] <model-in1> <model-in2> ... <model-inN> <valid-examples-in> <model-out>\n"
42  "\n"
43  "e.g.:\n"
44  " nnet-combine 1.1.nnet 1.2.nnet 1.3.nnet ark:valid.egs 2.nnet\n"
45  "Caution: the first input neural net must not be a gradient.\n";
46 
47  bool binary_write = true;
48  NnetCombineConfig combine_config;
49 
50  ParseOptions po(usage);
51  po.Register("binary", &binary_write, "Write output in binary mode");
52 
// Let the combination config register its own command-line options on `po`.
53  combine_config.Register(&po);
54 
55  po.Read(argc, argv);
56 
// At least three positional arguments are required: one input model, the
// validation examples, and the output model.
57  if (po.NumArgs() < 3) {
58  po.PrintUsage();
59  exit(1);
60  }
61 
// The first positional argument is the first input model; the last two are
// the validation-examples rspecifier and the output model filename.
62  std::string
63  nnet1_rxfilename = po.GetArg(1),
64  valid_examples_rspecifier = po.GetArg(po.NumArgs() - 1),
65  nnet_wxfilename = po.GetArg(po.NumArgs());
66 
// The transition model read from the first input is the one written to the
// output file at the end.
67  TransitionModel trans_model;
68  AmNnet am_nnet1;
69  {
70  bool binary_read;
71  Input ki(nnet1_rxfilename, &binary_read);
72  trans_model.Read(ki.Stream(), binary_read);
73  am_nnet1.Read(ki.Stream(), binary_read);
74  }
75 
// Every positional argument except the last two is an input model.
76  int32 num_nnets = po.NumArgs() - 2;
77  std::vector<Nnet> nnets(num_nnets);
78  nnets[0] = am_nnet1.GetNnet();
79  am_nnet1.GetNnet() = Nnet(); // Clear it to save memory.
80 
// Read the remaining input models (positional arguments 2 .. num_nnets).
81  for (int32 n = 1; n < num_nnets; n++) {
// This trans_model shadows the outer one: it is read from the stream ahead of
// the AmNnet and is not used afterwards.
82  TransitionModel trans_model;
83  AmNnet am_nnet;
84  bool binary_read;
85  Input ki(po.GetArg(1 + n), &binary_read);
86  trans_model.Read(ki.Stream(), binary_read);
87  am_nnet.Read(ki.Stream(), binary_read);
88  nnets[n] = am_nnet.GetNnet();
89  }
90 
91  std::vector<NnetExample> validation_set; // stores validation
92  // frames.
93 
94  { // This block adds samples to "validation_set".
95  SequentialNnetExampleReader example_reader(
96  valid_examples_rspecifier);
97  for (; !example_reader.Done(); example_reader.Next())
98  validation_set.push_back(example_reader.Value());
99  KALDI_LOG << "Read " << validation_set.size() << " examples from the "
100  << "validation set.";
101  KALDI_ASSERT(validation_set.size() > 0);
102  }
103 
// The last argument points at am_nnet1's (previously cleared) Nnet; presumably
// CombineNnets stores the combined network there -- confirm against
// combine-nnet.h.
104  CombineNnets(combine_config,
105  validation_set,
106  nnets,
107  &(am_nnet1.GetNnet()));
108 
// Write the first input's transition model followed by am_nnet1.
109  {
110  Output ko(nnet_wxfilename, binary_write);
111  trans_model.Write(ko.Stream(), binary_write);
112  am_nnet1.Write(ko.Stream(), binary_write);
113  }
114 
115  KALDI_LOG << "Finished combining neural nets, wrote model to "
116  << nnet_wxfilename;
// Exit status 1 only if the validation set was empty (the same condition the
// KALDI_ASSERT above checks); otherwise 0.
117  return (validation_set.size() == 0 ? 1 : 0);
118  } catch(const std::exception &e) {
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void Read(std::istream &is, bool binary)
Definition: am-nnet.cc:39
kaldi::int32 int32
void Register(OptionsItf *opts)
Definition: combine-nnet.h:50
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more documentation.
Definition: parse-options.h:36
void Read(std::istream &is, bool binary)
void Write(std::ostream &os, bool binary) const
Definition: am-nnet.cc:31
struct rnnlm::@11::@12 n
A templated class for reading objects sequentially from an archive or script file; see The Table concept.
Definition: kaldi-table.h:287
Configuration class that controls neural net combination, where we combine a number of neural nets.
Definition: combine-nnet.h:35
void Write(std::ostream &os, bool binary) const
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
static void CombineNnets(const Vector< BaseFloat > &scale_params, const std::vector< Nnet > &nnets, Nnet *dest)
Definition: combine-nnet.cc:28
#define KALDI_LOG
Definition: kaldi-error.h:153
const Nnet & GetNnet() const
Definition: am-nnet.h:61