nnet-combine-fast.cc File Reference
Include dependency graph for nnet-combine-fast.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 27 of file nnet-combine-fast.cc.

References kaldi::nnet2::CombineNnetsFast(), SequentialTableReader< Holder >::Done(), ParseOptions::GetArg(), AmNnet::GetNnet(), KALDI_ASSERT, KALDI_LOG, rnnlm::n, SequentialTableReader< Holder >::Next(), NnetCombineFastConfig::num_threads, ParseOptions::NumArgs(), ParseOptions::PrintUsage(), AmNnet::Read(), ParseOptions::Read(), TransitionModel::Read(), NnetCombineFastConfig::Register(), ParseOptions::Register(), Output::Stream(), Input::Stream(), SequentialTableReader< Holder >::Value(), AmNnet::Write(), and TransitionModel::Write().

27  {  // Body of main(): parse options, read the input models and validation examples, combine the nets, write the result.
28  try {
29  using namespace kaldi;
30  using namespace kaldi::nnet2;
31  typedef kaldi::int32 int32;
32  typedef kaldi::int64 int64;
33 
34  const char *usage =
35  "Using a validation set, compute an optimal combination of a number of\n"
36  "neural nets (the combination weights are separate for each layer and\n"
37  "do not have to sum to one). The optimization is BFGS, which is initialized\n"
38  "from the best of the individual input neural nets (or as specified by\n"
39  "--initial-model)\n"
40  "\n"
41  "Usage: nnet-combine-fast [options] <model-in1> <model-in2> ... <model-inN> <valid-examples-in> <model-out>\n"
42  "\n"
43  "e.g.:\n"
44  " nnet-combine-fast 1.1.nnet 1.2.nnet 1.3.nnet ark:valid.egs 2.nnet\n"
45  "Caution: the first input neural net must not be a gradient.\n";
46 
47  bool binary_write = true;  // default value of the --binary option
48  NnetCombineFastConfig combine_config;
49  std::string use_gpu = "yes";  // default value of the --use-gpu option
50 
51  ParseOptions po(usage);
52  po.Register("binary", &binary_write, "Write output in binary mode");
53  po.Register("use-gpu", &use_gpu,
54  "yes|no|optional|wait, only has effect if compiled with CUDA");
55 
56  combine_config.Register(&po);  // hands the option parser to the combination config so it can register itself
57 
58  po.Read(argc, argv);
59 
60  if (po.NumArgs() < 3) {  // need at least one input model, the validation examples and the output name
61  po.PrintUsage();
62  exit(1);  // NOTE(review): exits directly from inside the try block instead of returning - confirm this is intended
63  }
64 
65  std::string
66  nnet1_rxfilename = po.GetArg(1),  // first positional argument: the first input model
67  valid_examples_rspecifier = po.GetArg(po.NumArgs() - 1),  // second-to-last argument: validation examples
68  nnet_wxfilename = po.GetArg(po.NumArgs());  // last argument: where the combined model is written
69 
70 #if HAVE_CUDA==1
71  if (combine_config.num_threads == 1)  // a GPU is selected only in the single-threaded configuration
72  CuDevice::Instantiate().SelectGpuId(use_gpu);
73 #endif
74 
75 
76  TransitionModel trans_model;  // transition model of the first input; this is the one written to the output below
77  AmNnet am_nnet1;  // first input model; its Nnet is later reused as the output of the combination
78  {
79  bool binary_read;  // passed by pointer to Input; presumably set to the detected file mode - TODO confirm
80  Input ki(nnet1_rxfilename, &binary_read);
81  trans_model.Read(ki.Stream(), binary_read);  // the transition model is read first, then the AmNnet, from the same stream
82  am_nnet1.Read(ki.Stream(), binary_read);
83  }
84 
85  int32 num_nnets = po.NumArgs() - 2;  // every positional argument except the last two is an input model
86  std::vector<Nnet> nnets(num_nnets);
87  nnets[0] = am_nnet1.GetNnet();
88  am_nnet1.GetNnet() = Nnet(); // Clear it to save memory.
89 
90  for (int32 n = 1; n < num_nnets; n++) {  // remaining input models are positional arguments 2..num_nnets
91  TransitionModel trans_model;  // shadows the outer trans_model; read from the stream and then discarded
92  AmNnet am_nnet;
93  bool binary_read;
94  Input ki(po.GetArg(1 + n), &binary_read);
95  trans_model.Read(ki.Stream(), binary_read);
96  am_nnet.Read(ki.Stream(), binary_read);
97  nnets[n] = am_nnet.GetNnet();  // only the Nnet part of each additional model is kept
98  }
99 
100  std::vector<NnetExample> validation_set; // stores validation
101  // frames.
102 
103  { // This block adds samples to "validation_set".
104  SequentialNnetExampleReader example_reader(
105  valid_examples_rspecifier);
106  for (; !example_reader.Done(); example_reader.Next())
107  validation_set.push_back(example_reader.Value());  // copies every example into memory
108  KALDI_LOG << "Read " << validation_set.size() << " examples from the "
109  << "validation set.";
110  KALDI_ASSERT(validation_set.size() > 0);  // an empty validation set is rejected here
111  }
112 
113  CombineNnetsFast(combine_config,
114  validation_set,
115  nnets,
116  &(am_nnet1.GetNnet()));  // output goes into am_nnet1's Nnet, which was cleared above
117 
118  {
119  Output ko(nnet_wxfilename, binary_write);
120  trans_model.Write(ko.Stream(), binary_write);  // the outer trans_model, i.e. the one read from the first input model
121  am_nnet1.Write(ko.Stream(), binary_write);  // written in the same order the inputs were read: transition model, then AmNnet
122  }
123 
124  KALDI_LOG << "Finished combining neural nets, wrote model to "
125  << nnet_wxfilename;
126  return (validation_set.size() == 0 ? 1 : 0);  // NOTE(review): the KALDI_ASSERT above already rejects an empty set when assertions are enabled, so this looks like it always yields 0 - confirm
127  } catch(const std::exception &e) {  // any std::exception: print its message to stderr and return -1
128  std::cerr << e.what() << '\n';
129  return -1;
130  }
131 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void Read(std::istream &is, bool binary)
Definition: am-nnet.cc:39
kaldi::int32 int32
void CombineNnetsFast(const NnetCombineFastConfig &combine_config, const std::vector< NnetExample > &validation_set, const std::vector< Nnet > &nnets_in, Nnet *nnet_out)
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Read(std::istream &is, bool binary)
void Write(std::ostream &os, bool binary) const
Definition: am-nnet.cc:31
struct rnnlm::@11::@12 n
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
void Write(std::ostream &os, bool binary) const
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
Configuration class that controls neural net combination, where we combine a number of neural nets...
#define KALDI_LOG
Definition: kaldi-error.h:153
const Nnet & GetNnet() const
Definition: am-nnet.h:61