nnet-align-compiled.cc
Go to the documentation of this file.
1 // nnet2bin/nnet-align-compiled.cc
2 
3 // Copyright 2009-2012 Microsoft Corporation
4 // Johns Hopkins University (author: Daniel Povey)
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #include "base/kaldi-common.h"
22 #include "util/common-utils.h"
23 #include "gmm/am-diag-gmm.h"
24 #include "hmm/transition-model.h"
25 #include "hmm/hmm-utils.h"
26 #include "fstext/fstext-lib.h"
30 #include "lat/kaldi-lattice.h"
31 
32 int main(int argc, char *argv[]) {
33  try {
34  using namespace kaldi;
35  using namespace kaldi::nnet2;
36  typedef kaldi::int32 int32;
37  using fst::SymbolTable;
38  using fst::VectorFst;
39  using fst::StdArc;
40 
41  const char *usage =
42  "Align features given neural-net-based model\n"
43  "Usage: nnet-align-compiled [options] <model-in> <graphs-rspecifier> "
44  "<feature-rspecifier> <alignments-wspecifier>\n"
45  "e.g.: \n"
46  " nnet-align-compiled 1.mdl ark:graphs.fsts scp:train.scp ark:1.ali\n"
47  "or:\n"
48  " compile-train-graphs tree 1.mdl lex.fst 'ark:sym2int.pl -f 2- words.txt text|' \\\n"
49  " ark:- | nnet-align-compiled 1.mdl ark:- scp:train.scp t, ark:1.ali\n";
50 
51  ParseOptions po(usage);
52  AlignConfig align_config;
53  std::string use_gpu = "yes";
54  BaseFloat acoustic_scale = 1.0;
55  BaseFloat transition_scale = 1.0;
56  BaseFloat self_loop_scale = 1.0;
57  std::string per_frame_acwt_wspecifier;
58 
59  align_config.Register(&po);
60  po.Register("transition-scale", &transition_scale,
61  "Transition-probability scale [relative to acoustics]");
62  po.Register("acoustic-scale", &acoustic_scale,
63  "Scaling factor for acoustic likelihoods");
64  po.Register("self-loop-scale", &self_loop_scale,
65  "Scale of self-loop versus non-self-loop "
66  "log probs [relative to acoustics]");
67  po.Register("write-per-frame-acoustic-loglikes", &per_frame_acwt_wspecifier,
68  "Wspecifier for table of vectors containing the acoustic log-likelihoods "
69  "per frame for each utterance. E.g. ark:foo/per_frame_logprobs.1.ark");
70  po.Register("use-gpu", &use_gpu,
71  "yes|no|optional|wait, only has effect if compiled with CUDA");
72  po.Read(argc, argv);
73 
74  if (po.NumArgs() < 4 || po.NumArgs() > 5) {
75  po.PrintUsage();
76  exit(1);
77  }
78 
79 #if HAVE_CUDA==1
80  CuDevice::Instantiate().SelectGpuId(use_gpu);
81 #endif
82 
83  std::string model_in_filename = po.GetArg(1),
84  fst_rspecifier = po.GetArg(2),
85  feature_rspecifier = po.GetArg(3),
86  alignment_wspecifier = po.GetArg(4),
87  scores_wspecifier = po.GetOptArg(5);
88 
89  int num_done = 0, num_err = 0, num_retry = 0;
90  double tot_like = 0.0;
91  kaldi::int64 frame_count = 0;
92 
93  {
94  TransitionModel trans_model;
95  AmNnet am_nnet;
96  {
97  bool binary;
98  Input ki(model_in_filename, &binary);
99  trans_model.Read(ki.Stream(), binary);
100  am_nnet.Read(ki.Stream(), binary);
101  }
102 
103  SequentialTableReader<fst::VectorFstHolder> fst_reader(fst_rspecifier);
104  RandomAccessBaseFloatCuMatrixReader feature_reader(feature_rspecifier);
105  Int32VectorWriter alignment_writer(alignment_wspecifier);
106  BaseFloatWriter scores_writer(scores_wspecifier);
107  BaseFloatVectorWriter per_frame_acwt_writer(per_frame_acwt_wspecifier);
108 
109  for (; !fst_reader.Done(); fst_reader.Next()) {
110  std::string utt = fst_reader.Key();
111  if (!feature_reader.HasKey(utt)) {
112  KALDI_WARN << "No features for utterance " << utt;
113  num_err++;
114  continue;
115  }
116  const CuMatrix<BaseFloat> &features = feature_reader.Value(utt);
117  VectorFst<StdArc> decode_fst(fst_reader.Value());
118  fst_reader.FreeCurrent(); // this stops copy-on-write of the fst
119  // by deleting the fst inside the reader, since we're about to mutate
120  // the fst by adding transition probs.
121 
122  if (features.NumRows() == 0) {
123  KALDI_WARN << "Zero-length utterance: " << utt;
124  num_err++;
125  continue;
126  }
127 
128  { // Add transition-probs to the FST.
129  std::vector<int32> disambig_syms; // empty.
130  AddTransitionProbs(trans_model, disambig_syms,
131  transition_scale, self_loop_scale,
132  &decode_fst);
133  }
134 
135  bool pad_input = true;
136  DecodableAmNnet nnet_decodable(trans_model, am_nnet, features,
137  pad_input, acoustic_scale);
138 
139  AlignUtteranceWrapper(align_config, utt,
140  acoustic_scale, &decode_fst, &nnet_decodable,
141  &alignment_writer, &scores_writer,
142  &num_done, &num_err, &num_retry,
143  &tot_like, &frame_count, &per_frame_acwt_writer);
144  }
145  KALDI_LOG << "Overall log-likelihood per frame is " << (tot_like/frame_count)
146  << " over " << frame_count<< " frames.";
147  KALDI_LOG << "Retried " << num_retry << " out of "
148  << (num_done + num_err) << " utterances.";
149  KALDI_LOG << "Done " << num_done << ", errors on " << num_err;
150  }
151 #if HAVE_CUDA==1
152  CuDevice::Instantiate().PrintProfile();
153 #endif
154  return (num_done != 0 ? 0 : 1);
155  } catch(const std::exception &e) {
156  std::cerr << e.what();
157  return -1;
158  }
159 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void Register(OptionsItf *opts)
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
fst::StdArc StdArc
void Read(std::istream &is, bool binary)
Definition: am-nnet.cc:39
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
Definition: matrix-common.h:71
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
std::istream & Stream()
Definition: kaldi-io.cc:826
void AddTransitionProbs(const TransitionModel &trans_model, const std::vector< int32 > &disambig_syms, BaseFloat transition_scale, BaseFloat self_loop_scale, fst::VectorFst< fst::StdArc > *fst)
Adds transition-probs, with the supplied scales (see Scaling of transition and acoustic probabilities...
Definition: hmm-utils.cc:1088
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const T & Value(const std::string &key)
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
bool HasKey(const std::string &key)
int NumArgs() const
Number of positional parameters (c.f. argc-1).
DecodableAmNnet is a decodable object that decodes with a neural net acoustic model of type AmNnet...
int main(int argc, char *argv[])
MatrixIndexT NumRows() const
Dimensions.
Definition: cu-matrix.h:215
void AlignUtteranceWrapper(const AlignConfig &config, const std::string &utt, BaseFloat acoustic_scale, fst::VectorFst< fst::StdArc > *fst, DecodableInterface *decodable, Int32VectorWriter *alignment_writer, BaseFloatWriter *scores_writer, int32 *num_done, int32 *num_error, int32 *num_retried, double *tot_like, int64 *frame_count, BaseFloatVectorWriter *per_frame_acwt_writer)
AlignUtteranceWrapper is a wrapper for alignment code used in training, that is called from many diffe...
#define KALDI_LOG
Definition: kaldi-error.h:153
std::string GetOptArg(int param) const