nnet-train-transitions.cc File Reference
Include dependency graph for nnet-train-transitions.cc:

Go to the source code of this file.

Namespaces

 kaldi
 This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for mispronunciation detection tasks; see the reference:
 
 kaldi::nnet2
 

Functions

void SetPriors (const TransitionModel &tmodel, const Vector< double > &transition_accs, double prior_floor, AmNnet *am_nnet)
 
int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 53 of file nnet-train-transitions.cc.

References TransitionModel::Accumulate(), count, SequentialTableReader< Holder >::Done(), ParseOptions::GetArg(), rnnlm::i, TransitionModel::InitStats(), KALDI_LOG, TransitionModel::MleUpdate(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), ParseOptions::PrintUsage(), AmNnet::Read(), ParseOptions::Read(), TransitionModel::Read(), ParseOptions::Register(), MleTransitionUpdateConfig::Register(), kaldi::nnet2::SetPriors(), Output::Stream(), Input::Stream(), SequentialTableReader< Holder >::Value(), AmNnet::Write(), and TransitionModel::Write().

53  {
54  try {
55  using namespace kaldi;
56  using namespace kaldi::nnet2;
57  typedef kaldi::int32 int32;
58 
59  const char *usage =
60  "Train the transition probabilities of a neural network acoustic model\n"
61  "\n"
62  "Usage: nnet-train-transitions [options] <nnet-in> <alignments-rspecifier> <nnet-out>\n"
63  "e.g.:\n"
64  " nnet-train-transitions 1.nnet \"ark:gunzip -c ali.*.gz|\" 2.nnet\n";
65 
66  bool binary_write = true;
67  bool set_priors = true; // Also set the per-pdf priors in the model.
68  BaseFloat prior_floor = 5.0e-06; // The default was previously 1e-8, but
69  // once we had problems with a pdf-id that
70  // was not being seen in training, being
71  // recognized all the time. This value
72  // seemed to be the smallest prior of the
73  // "seen" pdf-ids in one run.
74  MleTransitionUpdateConfig transition_update_config;
75 
76  ParseOptions po(usage);
77  po.Register("binary", &binary_write, "Write output in binary mode");
78  po.Register("set-priors", &set_priors, "If true, also set priors in neural "
79  "net (we divide by these in test time)");
80  po.Register("prior-floor", &prior_floor, "When setting priors, floor for "
81  "priors");
82  transition_update_config.Register(&po);
83 
84  po.Read(argc, argv);
85 
86  if (po.NumArgs() != 3) {
87  po.PrintUsage();
88  exit(1);
89  }
90 
91  std::string nnet_rxfilename = po.GetArg(1),
92  ali_rspecifier = po.GetArg(2),
93  nnet_wxfilename = po.GetArg(3);
94 
95  TransitionModel trans_model;
96  AmNnet am_nnet;
97  {
98  bool binary_read;
99  Input ki(nnet_rxfilename, &binary_read);
100  trans_model.Read(ki.Stream(), binary_read);
101  am_nnet.Read(ki.Stream(), binary_read);
102  }
103 
104  Vector<double> transition_accs;
105  trans_model.InitStats(&transition_accs);
106 
107  int32 num_done = 0;
108  SequentialInt32VectorReader ali_reader(ali_rspecifier);
109  for (; ! ali_reader.Done(); ali_reader.Next()) {
110  const std::vector<int32> alignment(ali_reader.Value());
111  for (size_t i = 0; i < alignment.size(); i++) {
112  int32 tid = alignment[i];
113  BaseFloat weight = 1.0;
114  trans_model.Accumulate(weight, tid, &transition_accs);
115  }
116  num_done++;
117  }
118  KALDI_LOG << "Accumulated transition stats from " << num_done
119  << " utterances.";
120 
121  {
122  BaseFloat objf_impr, count;
123  trans_model.MleUpdate(transition_accs, transition_update_config,
124  &objf_impr, &count);
125  KALDI_LOG << "Transition model update: average " << (objf_impr/count)
126  << " log-like improvement per frame over " << count
127  << " frames.";
128  }
129 
130  if (set_priors) {
131  KALDI_LOG << "Setting priors of pdfs in the model.";
132  SetPriors(trans_model, transition_accs, prior_floor, &am_nnet);
133  }
134 
135  {
136  Output ko(nnet_wxfilename, binary_write);
137  trans_model.Write(ko.Stream(), binary_write);
138  am_nnet.Write(ko.Stream(), binary_write);
139  }
140  KALDI_LOG << "Trained transitions of neural network model and wrote it to "
141  << nnet_wxfilename;
142  return 0;
143  } catch(const std::exception &e) {
144  std::cerr << e.what() << '\n';
145  return -1;
146  }
147 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
Definition: chain.dox:20
void MleUpdate(const Vector< double > &stats, const MleTransitionUpdateConfig &cfg, BaseFloat *objf_impr_out, BaseFloat *count_out)
Does Maximum Likelihood estimation.
void Read(std::istream &is, bool binary)
Definition: am-nnet.cc:39
kaldi::int32 int32
const size_t count
float BaseFloat
Definition: kaldi-types.h:29
void InitStats(Vector< double > *stats) const
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Read(std::istream &is, bool binary)
void Accumulate(BaseFloat prob, int32 trans_id, Vector< double > *stats) const
void Write(std::ostream &os, bool binary) const
Definition: am-nnet.cc:31
void SetPriors(const TransitionModel &tmodel, const Vector< double > &transition_accs, double prior_floor, AmNnet *am_nnet)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
void Write(std::ostream &os, bool binary) const
void Register(OptionsItf *opts)
#define KALDI_LOG
Definition: kaldi-error.h:153