nnet3-am-train-transitions.cc
Go to the documentation of this file.
1 // nnet3bin/nnet3-am-train-transitions.cc
2 
3 // Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include "base/kaldi-common.h"
21 #include "util/common-utils.h"
22 #include "hmm/transition-model.h"
23 #include "nnet3/am-nnet-simple.h"
24 #include "tree/context-dep.h"
25 
26 namespace kaldi {
27 namespace nnet3 {
28 void SetPriors(const TransitionModel &tmodel,
29  const Vector<double> &transition_accs,
30  double prior_floor,
31  AmNnetSimple *am_nnet) {
32  KALDI_ASSERT(tmodel.NumPdfs() == am_nnet->NumPdfs());
33  Vector<BaseFloat> pdf_counts(tmodel.NumPdfs());
34  KALDI_ASSERT(transition_accs(0) == 0.0); // There is
35  // no zero transition-id.
36  for (int32 tid = 1; tid < transition_accs.Dim(); tid++) {
37  int32 pdf = tmodel.TransitionIdToPdf(tid);
38  pdf_counts(pdf) += transition_accs(tid);
39  }
40  BaseFloat sum = pdf_counts.Sum();
41  KALDI_ASSERT(sum != 0.0);
42  KALDI_ASSERT(prior_floor > 0.0 && prior_floor < 1.0);
43  pdf_counts.Scale(1.0 / sum);
44  pdf_counts.ApplyFloor(prior_floor);
45  pdf_counts.Scale(1.0 / pdf_counts.Sum()); // normalize again.
46  am_nnet->SetPriors(pdf_counts);
47 }
48 
49 
50 } // namespace nnet3
51 } // namespace kaldi
52 
53 int main(int argc, char *argv[]) {
54  try {
55  using namespace kaldi;
56  using namespace kaldi::nnet3;
57  typedef kaldi::int32 int32;
58 
59  const char *usage =
60  "Train the transition probabilities of an nnet3 neural network acoustic model\n"
61  "\n"
62  "Usage: nnet3-am-train-transitions [options] <nnet-in> <alignments-rspecifier> <nnet-out>\n"
63  "e.g.:\n"
64  " nnet3-am-train-transitions 1.nnet \"ark:gunzip -c ali.*.gz|\" 2.nnet\n";
65 
66  bool binary_write = true;
67  bool set_priors = true; // Also set the per-pdf priors in the model.
68  BaseFloat prior_floor = 5.0e-06; // The default was previously 1e-8, but
69  // once we had problems with a pdf-id that
70  // was not being seen in training, being
71  // recognized all the time. This value
72  // seemed to be the smallest prior of the
73  // "seen" pdf-ids in one run.
74  MleTransitionUpdateConfig transition_update_config;
75 
76  ParseOptions po(usage);
77  po.Register("binary", &binary_write, "Write output in binary mode");
78  po.Register("set-priors", &set_priors, "If true, also set priors in neural "
79  "net (we divide by these in test time)");
80  po.Register("prior-floor", &prior_floor, "When setting priors, floor for "
81  "priors");
82  transition_update_config.Register(&po);
83 
84  po.Read(argc, argv);
85 
86  if (po.NumArgs() != 3) {
87  po.PrintUsage();
88  exit(1);
89  }
90 
91  std::string nnet_rxfilename = po.GetArg(1),
92  ali_rspecifier = po.GetArg(2),
93  nnet_wxfilename = po.GetArg(3);
94 
95  TransitionModel trans_model;
96  AmNnetSimple am_nnet;
97  {
98  bool binary_read;
99  Input ki(nnet_rxfilename, &binary_read);
100  trans_model.Read(ki.Stream(), binary_read);
101  am_nnet.Read(ki.Stream(), binary_read);
102  }
103 
104  Vector<double> transition_accs;
105  trans_model.InitStats(&transition_accs);
106 
107  int32 num_done = 0;
108  SequentialInt32VectorReader ali_reader(ali_rspecifier);
109  for (; ! ali_reader.Done(); ali_reader.Next()) {
110  const std::vector<int32> alignment(ali_reader.Value());
111  for (size_t i = 0; i < alignment.size(); i++) {
112  int32 tid = alignment[i];
113  BaseFloat weight = 1.0;
114  trans_model.Accumulate(weight, tid, &transition_accs);
115  }
116  num_done++;
117  }
118  KALDI_LOG << "Accumulated transition stats from " << num_done
119  << " utterances.";
120 
121  {
122  BaseFloat objf_impr, count;
123  trans_model.MleUpdate(transition_accs, transition_update_config,
124  &objf_impr, &count);
125  KALDI_LOG << "Transition model update: average " << (objf_impr/count)
126  << " log-like improvement per frame over " << count
127  << " frames.";
128  }
129 
130  if (set_priors) {
131  KALDI_LOG << "Setting priors of pdfs in the model.";
132  SetPriors(trans_model, transition_accs, prior_floor, &am_nnet);
133  }
134 
135  {
136  Output ko(nnet_wxfilename, binary_write);
137  trans_model.Write(ko.Stream(), binary_write);
138  am_nnet.Write(ko.Stream(), binary_write);
139  }
140  KALDI_LOG << "Trained transitions of neural network model and wrote it to "
141  << nnet_wxfilename;
142  return 0;
143  } catch(const std::exception &e) {
144  std::cerr << e.what() << '\n';
145  return -1;
146  }
147 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void MleUpdate(const Vector< double > &stats, const MleTransitionUpdateConfig &cfg, BaseFloat *objf_impr_out, BaseFloat *count_out)
Does Maximum Likelihood estimation.
void SetPriors(const TransitionModel &tmodel, const Vector< double > &transition_accs, double prior_floor, AmNnetSimple *am_nnet)
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
kaldi::int32 int32
int32 TransitionIdToPdf(int32 trans_id) const
void Read(std::istream &is, bool binary)
void Register(const std::string &name, bool *ptr, const std::string &doc)
int main(int argc, char *argv[])
void SetPriors(const VectorBase< BaseFloat > &priors)
const size_t count
std::istream & Stream()
Definition: kaldi-io.cc:826
float BaseFloat
Definition: kaldi-types.h:29
void InitStats(Vector< double > *stats) const
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
std::ostream & Stream()
Definition: kaldi-io.cc:701
void Read(std::istream &is, bool binary)
void Accumulate(BaseFloat prob, int32 trans_id, Vector< double > *stats) const
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void Write(std::ostream &os, bool binary) const
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void Register(OptionsItf *opts)
#define KALDI_LOG
Definition: kaldi-error.h:153
void Write(std::ostream &os, bool binary) const