nnet3-align-compiled.cc File Reference
Include dependency graph for nnet3-align-compiled.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 35 of file nnet3-align-compiled.cc.

References NnetSimpleComputationOptions::acoustic_scale, kaldi::AddTransitionProbs(), kaldi::AlignUtteranceWrapper(), kaldi::nnet3::CollapseModel(), SequentialTableReader< Holder >::Done(), SequentialTableReader< Holder >::FreeCurrent(), ParseOptions::GetArg(), AmNnetSimple::GetNnet(), ParseOptions::GetOptArg(), RandomAccessTableReader< Holder >::HasKey(), RandomAccessTableReaderMapped< Holder >::HasKey(), KALDI_LOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), MatrixBase< Real >::NumRows(), NnetSimpleComputationOptions::optimize_config, ParseOptions::PrintUsage(), AmNnetSimple::Read(), ParseOptions::Read(), TransitionModel::Read(), AlignConfig::Register(), NnetSimpleComputationOptions::Register(), ParseOptions::Register(), kaldi::nnet3::SetBatchnormTestMode(), kaldi::nnet3::SetDropoutTestMode(), Input::Stream(), RandomAccessTableReader< Holder >::Value(), SequentialTableReader< Holder >::Value(), and RandomAccessTableReaderMapped< Holder >::Value().

35  {
36  try {
37  using namespace kaldi;
38  using namespace kaldi::nnet3;
39  typedef kaldi::int32 int32;
40  using fst::SymbolTable;
41  using fst::VectorFst;
42  using fst::StdArc;
43 
44  const char *usage =
45  "Align features given nnet3 neural net model\n"
46  "Usage: nnet3-align-compiled [options] <nnet-in> <graphs-rspecifier> "
47  "<features-rspecifier> <alignments-wspecifier>\n"
48  "e.g.: \n"
49  " nnet3-align-compiled 1.mdl ark:graphs.fsts scp:train.scp ark:1.ali\n"
50  "or:\n"
51  " compile-train-graphs tree 1.mdl lex.fst 'ark:sym2int.pl -f 2- words.txt text|' \\\n"
52  " ark:- | nnet3-align-compiled 1.mdl ark:- scp:train.scp t, ark:1.ali\n";
53 
54  ParseOptions po(usage);
55  AlignConfig align_config;
56  NnetSimpleComputationOptions decodable_opts;
57  std::string use_gpu = "yes";
58  BaseFloat transition_scale = 1.0;
59  BaseFloat self_loop_scale = 1.0;
60  std::string per_frame_acwt_wspecifier;
61 
62  std::string ivector_rspecifier,
63  online_ivector_rspecifier,
64  utt2spk_rspecifier;
65  int32 online_ivector_period = 0;
66  align_config.Register(&po);
67  decodable_opts.Register(&po);
68 
69  po.Register("use-gpu", &use_gpu,
70  "yes|no|optional|wait, only has effect if compiled with CUDA");
71  po.Register("transition-scale", &transition_scale,
72  "Transition-probability scale [relative to acoustics]");
73  po.Register("self-loop-scale", &self_loop_scale,
74  "Scale of self-loop versus non-self-loop "
75  "log probs [relative to acoustics]");
76  po.Register("write-per-frame-acoustic-loglikes", &per_frame_acwt_wspecifier,
77  "Wspecifier for table of vectors containing the acoustic log-likelihoods "
78  "per frame for each utterance. E.g. ark:foo/per_frame_logprobs.1.ark");
79  po.Register("ivectors", &ivector_rspecifier, "Rspecifier for "
80  "iVectors as vectors (i.e. not estimated online); per utterance "
81  "by default, or per speaker if you provide the --utt2spk option.");
82  po.Register("online-ivectors", &online_ivector_rspecifier, "Rspecifier for "
83  "iVectors estimated online, as matrices. If you supply this,"
84  " you must set the --online-ivector-period option.");
85  po.Register("online-ivector-period", &online_ivector_period, "Number of frames "
86  "between iVectors in matrices supplied to the --online-ivectors "
87  "option");
88  po.Read(argc, argv);
89 
90  if (po.NumArgs() < 4 || po.NumArgs() > 5) {
91  po.PrintUsage();
92  exit(1);
93  }
94 
95 #if HAVE_CUDA==1
96  CuDevice::Instantiate().SelectGpuId(use_gpu);
97 #endif
98 
99  std::string model_in_filename = po.GetArg(1),
100  fst_rspecifier = po.GetArg(2),
101  feature_rspecifier = po.GetArg(3),
102  alignment_wspecifier = po.GetArg(4),
103  scores_wspecifier = po.GetOptArg(5);
104 
105  int num_done = 0, num_err = 0, num_retry = 0;
106  double tot_like = 0.0;
107  kaldi::int64 frame_count = 0;
108 
109 
110  {
111  TransitionModel trans_model;
112  AmNnetSimple am_nnet;
113  {
114  bool binary;
115  Input ki(model_in_filename, &binary);
116  trans_model.Read(ki.Stream(), binary);
117  am_nnet.Read(ki.Stream(), binary);
118  }
119  SetBatchnormTestMode(true, &(am_nnet.GetNnet()));
120  SetDropoutTestMode(true, &(am_nnet.GetNnet()));
121  CollapseModel(CollapseModelConfig(), &(am_nnet.GetNnet()));
122  // this compiler object allows caching of computations across
123  // different utterances.
124  CachingOptimizingCompiler compiler(am_nnet.GetNnet(),
125  decodable_opts.optimize_config);
126 
127  RandomAccessBaseFloatMatrixReader online_ivector_reader(
128  online_ivector_rspecifier);
130  ivector_rspecifier, utt2spk_rspecifier);
131 
132 
133  SequentialTableReader<fst::VectorFstHolder> fst_reader(fst_rspecifier);
134  RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier);
135  Int32VectorWriter alignment_writer(alignment_wspecifier);
136  BaseFloatWriter scores_writer(scores_wspecifier);
137  BaseFloatVectorWriter per_frame_acwt_writer(per_frame_acwt_wspecifier);
138 
139  for (; !fst_reader.Done(); fst_reader.Next()) {
140  std::string utt = fst_reader.Key();
141  if (!feature_reader.HasKey(utt)) {
142  KALDI_WARN << "No features for utterance " << utt;
143  num_err++;
144  continue;
145  }
146  const Matrix<BaseFloat> &features = feature_reader.Value(utt);
147  VectorFst<StdArc> decode_fst(fst_reader.Value());
148  fst_reader.FreeCurrent(); // this stops copy-on-write of the fst
149  // by deleting the fst inside the reader, since we're about to mutate
150  // the fst by adding transition probs.
151 
152  if (features.NumRows() == 0) {
153  KALDI_WARN << "Zero-length utterance: " << utt;
154  num_err++;
155  continue;
156  }
157 
158  const Matrix<BaseFloat> *online_ivectors = NULL;
159  const Vector<BaseFloat> *ivector = NULL;
160  if (!ivector_rspecifier.empty()) {
161  if (!ivector_reader.HasKey(utt)) {
162  KALDI_WARN << "No iVector available for utterance " << utt;
163  num_err++;
164  continue;
165  } else {
166  ivector = &ivector_reader.Value(utt);
167  }
168  }
169  if (!online_ivector_rspecifier.empty()) {
170  if (!online_ivector_reader.HasKey(utt)) {
171  KALDI_WARN << "No online iVector available for utterance " << utt;
172  num_err++;
173  continue;
174  } else {
175  online_ivectors = &online_ivector_reader.Value(utt);
176  }
177  }
178 
179  { // Add transition-probs to the FST.
180  std::vector<int32> disambig_syms; // empty.
181  AddTransitionProbs(trans_model, disambig_syms,
182  transition_scale, self_loop_scale,
183  &decode_fst);
184  }
185 
186  DecodableAmNnetSimple nnet_decodable(
187  decodable_opts, trans_model, am_nnet,
188  features, ivector, online_ivectors,
189  online_ivector_period, &compiler);
190 
191  AlignUtteranceWrapper(align_config, utt,
192  decodable_opts.acoustic_scale,
193  &decode_fst, &nnet_decodable,
194  &alignment_writer, &scores_writer,
195  &num_done, &num_err, &num_retry,
196  &tot_like, &frame_count, &per_frame_acwt_writer);
197  }
198  KALDI_LOG << "Overall log-likelihood per frame is "
199  << (tot_like/frame_count)
200  << " over " << frame_count<< " frames.";
201  KALDI_LOG << "Retried " << num_retry << " out of "
202  << (num_done + num_err) << " utterances.";
203  KALDI_LOG << "Done " << num_done << ", errors on " << num_err;
204  }
205 
206 #if HAVE_CUDA==1
207  CuDevice::Instantiate().PrintProfile();
208 #endif
209  return (num_done != 0 ? 0 : 1);
210  } catch(const std::exception &e) {
211  std::cerr << e.what();
212  return -1;
213  }
214 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
Definition: chain.dox:20
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable to be done at test time...
Definition: nnet-utils.cc:2100
void Register(OptionsItf *opts)
fst::StdArc StdArc
This class is for when you are reading something in random access, but it may actually be stored per-...
Definition: kaldi-table.h:432
This class enables you to do the compilation and optimization in one call, and also ensures that if t...
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
Definition: nnet-utils.cc:564
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
const Nnet & GetNnet() const
void Read(std::istream &is, bool binary)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
Definition: nnet-utils.cc:573
void AddTransitionProbs(const TransitionModel &trans_model, const std::vector< int32 > &disambig_syms, BaseFloat transition_scale, BaseFloat self_loop_scale, fst::VectorFst< fst::StdArc > *fst)
Adds transition-probs, with the supplied scales (see Scaling of transition and acoustic probabilities...
Definition: hmm-utils.cc:1088
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
#define KALDI_WARN
Definition: kaldi-error.h:150
A class representing a vector.
Definition: kaldi-vector.h:406
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
void AlignUtteranceWrapper(const AlignConfig &config, const std::string &utt, BaseFloat acoustic_scale, fst::VectorFst< fst::StdArc > *fst, DecodableInterface *decodable, Int32VectorWriter *alignment_writer, BaseFloatWriter *scores_writer, int32 *num_done, int32 *num_error, int32 *num_retried, double *tot_like, int64 *frame_count, BaseFloatVectorWriter *per_frame_acwt_writer)
AlignUtteranceWrapper is a wrapper for alignment code used in training, that is called from many diffe...
#define KALDI_LOG
Definition: kaldi-error.h:153
Config class for the CollapseModel function.
Definition: nnet-utils.h:240