lattice-lmrescore-kaldi-rnnlm-pruned.cc
Go to the documentation of this file.
// latbin/lattice-lmrescore-kaldi-rnnlm-pruned.cc

// Copyright 2017  Johns Hopkins University (author: Daniel Povey)
//           2017  Hainan Xu

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.


#include "base/kaldi-common.h"
#include "fstext/fstext-lib.h"
#include "rnnlm/rnnlm-lattice-rescoring.h"
#include "lm/const-arpa-lm.h"
#include "util/common-utils.h"
#include "nnet3/nnet-utils.h"
#include "lat/kaldi-lattice.h"
#include "lat/lattice-functions.h"
#include "lat/compose-lattice-pruned.h"

32 int main(int argc, char *argv[]) {
33  try {
34  using namespace kaldi;
35  typedef kaldi::int32 int32;
36  typedef kaldi::int64 int64;
37  using fst::SymbolTable;
38  using fst::VectorFst;
39  using fst::StdArc;
40  using fst::ReadFstKaldi;
41 
42  const char *usage =
43  "Rescores lattice with kaldi-rnnlm. This script is called from \n"
44  "scripts/rnnlm/lmrescore_pruned.sh. An example for rescoring \n"
45  "lattices is at egs/swbd/s5c/local/rnnlm/run_lstm.sh \n"
46  "\n"
47  "Usage: lattice-lmrescore-kaldi-rnnlm-pruned [options] \\\n"
48  " <old-lm-rxfilename> <embedding-file> \\\n"
49  " <raw-rnnlm-rxfilename> \\\n"
50  " <lattice-rspecifier> <lattice-wspecifier>\n"
51  " e.g.: lattice-lmrescore-kaldi-rnnlm-pruned --lm-scale=-1.0 fst_words.txt \\\n"
52  " --bos-symbol=1 --eos-symbol=2 \\\n"
53  " data/lang_test/G.fst word_embedding.mat \\\n"
54  " final.raw ark:in.lats ark:out.lats\n\n"
55  " lattice-lmrescore-kaldi-rnnlm-pruned --lm-scale=-1.0 fst_words.txt \\\n"
56  " --bos-symbol=1 --eos-symbol=2 \\\n"
57  " data/lang_test_fg/G.carpa word_embedding.mat \\\n"
58  " final.raw ark:in.lats ark:out.lats\n";
59 
60  ParseOptions po(usage);
61  rnnlm::RnnlmComputeStateComputationOptions opts;
62  ComposeLatticePrunedOptions compose_opts;
63 
64  int32 max_ngram_order = 3;
65  BaseFloat lm_scale = 0.5;
66  BaseFloat acoustic_scale = 0.1;
67  bool use_carpa = false;
68 
69  po.Register("lm-scale", &lm_scale, "Scaling factor for <lm-to-add>; its negative "
70  "will be applied to <lm-to-subtract>.");
71  po.Register("acoustic-scale", &acoustic_scale, "Scaling factor for acoustic "
72  "probabilities (e.g. 0.1 for non-chain systems); important because "
73  "of its effect on pruning.");
74  po.Register("max-ngram-order", &max_ngram_order,
75  "If positive, allow RNNLM histories longer than this to be identified "
76  "with each other for rescoring purposes (an approximation that "
77  "saves time and reduces output lattice size).");
78  po.Register("use-const-arpa", &use_carpa, "If true, read the old-LM file "
79  "as a const-arpa file as opposed to an FST file");
80 
81  opts.Register(&po);
82  compose_opts.Register(&po);
83 
84  po.Read(argc, argv);
85 
86  if (po.NumArgs() != 5) {
87  po.PrintUsage();
88  exit(1);
89  }
90 
91  if (opts.bos_index == -1 || opts.eos_index == -1) {
92  KALDI_ERR << "must set --bos-symbol and --eos-symbol options";
93  }
94 
95  std::string lm_to_subtract_rxfilename, lats_rspecifier,
96  word_embedding_rxfilename, rnnlm_rxfilename, lats_wspecifier;
97 
98  lm_to_subtract_rxfilename = po.GetArg(1),
99  word_embedding_rxfilename = po.GetArg(2);
100  rnnlm_rxfilename = po.GetArg(3);
101  lats_rspecifier = po.GetArg(4);
102  lats_wspecifier = po.GetArg(5);
103 
104  // for G.fst
105  fst::ScaleDeterministicOnDemandFst *lm_to_subtract_det_scale = NULL;
106  fst::BackoffDeterministicOnDemandFst<StdArc> *lm_to_subtract_det_backoff = NULL;
107  VectorFst<StdArc> *lm_to_subtract_fst = NULL;
108 
109  // for G.carpa
110  ConstArpaLm* const_arpa = NULL;
111  fst::DeterministicOnDemandFst<StdArc> *carpa_lm_to_subtract_fst = NULL;
112 
113  KALDI_LOG << "Reading old LMs...";
114  if (use_carpa) {
115  const_arpa = new ConstArpaLm();
116  ReadKaldiObject(lm_to_subtract_rxfilename, const_arpa);
117  carpa_lm_to_subtract_fst = new ConstArpaLmDeterministicFst(*const_arpa);
118  lm_to_subtract_det_scale
119  = new fst::ScaleDeterministicOnDemandFst(-lm_scale,
120  carpa_lm_to_subtract_fst);
121  } else {
122  lm_to_subtract_fst = fst::ReadAndPrepareLmFst(
123  lm_to_subtract_rxfilename);
124  lm_to_subtract_det_backoff =
125  new fst::BackoffDeterministicOnDemandFst<StdArc>(*lm_to_subtract_fst);
126  lm_to_subtract_det_scale =
128  lm_to_subtract_det_backoff);
129  }
130 
132  ReadKaldiObject(rnnlm_rxfilename, &rnnlm);
133 
134  KALDI_ASSERT(IsSimpleNnet(rnnlm));
135 
136  CuMatrix<BaseFloat> word_embedding_mat;
137  ReadKaldiObject(word_embedding_rxfilename, &word_embedding_mat);
138 
139  const rnnlm::RnnlmComputeStateInfo info(opts, rnnlm, word_embedding_mat);
140 
141  // Reads and writes as compact lattice.
142  SequentialCompactLatticeReader compact_lattice_reader(lats_rspecifier);
143  CompactLatticeWriter compact_lattice_writer(lats_wspecifier);
144 
145  int32 num_done = 0, num_err = 0;
146 
147  rnnlm::KaldiRnnlmDeterministicFst* lm_to_add_orig =
148  new rnnlm::KaldiRnnlmDeterministicFst(max_ngram_order, info);
149 
150  for (; !compact_lattice_reader.Done(); compact_lattice_reader.Next()) {
152  new fst::ScaleDeterministicOnDemandFst(lm_scale, lm_to_add_orig);
153 
154  std::string key = compact_lattice_reader.Key();
155  CompactLattice clat = compact_lattice_reader.Value();
156  compact_lattice_reader.FreeCurrent();
157 
158  // Before composing with the LM FST, we scale the lattice weights
159  // by the inverse of "lm_scale". We'll later scale by "lm_scale".
160  // We do it this way so we can determinize and it will give the
161  // right effect (taking the "best path" through the LM) regardless
162  // of the sign of lm_scale.
163  if (acoustic_scale != 1.0) {
164  fst::ScaleLattice(fst::AcousticLatticeScale(acoustic_scale), &clat);
165  }
167 
169  lm_to_subtract_det_scale, lm_to_add);
170 
171  // Composes lattice with language model.
172  CompactLattice composed_clat;
173  ComposeCompactLatticePruned(compose_opts, clat,
174  &combined_lms, &composed_clat);
175 
176  lm_to_add_orig->Clear();
177 
178  if (composed_clat.NumStates() == 0) {
179  // Something went wrong. A warning will already have been printed.
180  num_err++;
181  } else {
182  if (acoustic_scale != 1.0) {
183  if (acoustic_scale == 0.0)
184  KALDI_ERR << "Acoustic scale cannot be zero.";
185  fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale),
186  &composed_clat);
187  }
188  compact_lattice_writer.Write(key, composed_clat);
189  num_done++;
190  }
191  delete lm_to_add;
192  }
193 
194  delete lm_to_subtract_fst;
195  delete lm_to_add_orig;
196  delete lm_to_subtract_det_backoff;
197  delete lm_to_subtract_det_scale;
198 
199  delete const_arpa;
200  delete carpa_lm_to_subtract_fst;
201 
202  KALDI_LOG << "Overall, succeeded for " << num_done
203  << " lattices, failed for " << num_err;
204  return (num_done != 0 ? 0 : 1);
205  } catch(const std::exception &e) {
206  std::cerr << e.what();
207  return -1;
208  }
209 }
This class wraps an Fst, representing a language model, using the interface for "BackoffDeterministic...
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
Class ScaleDeterministicOnDemandFst takes another DeterministicOnDemandFst and scales the weights (li...
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
fst::StdArc StdArc
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
Definition: matrix-common.h:71
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
This file contains some miscellaneous functions dealing with class Nnet.
std::vector< std::vector< double > > AcousticLatticeScale(double acwt)
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void ScaleLattice(const std::vector< std::vector< ScaleFloat > > &scale, MutableFst< ArcTpl< Weight > > *fst)
Scales the pairs of weights in LatticeWeight or CompactLatticeWeight by viewing the pair (a...
fst::VectorFst< fst::StdArc > * ReadAndPrepareLmFst(std::string rxfilename)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
int main(int argc, char *argv[])
void ComposeCompactLatticePruned(const ComposeLatticePrunedOptions &opts, const CompactLattice &clat, fst::DeterministicOnDemandFst< fst::StdArc > *det_fst, CompactLattice *composed_clat)
Does pruned composition of a lattice 'clat' with a DeterministicOnDemandFst 'det_fst'; implements LM ...
fst::VectorFst< CompactLatticeArc > CompactLattice
Definition: kaldi-lattice.h:46
int NumArgs() const
Number of positional parameters (c.f. argc-1).
bool IsSimpleNnet(const Nnet &nnet)
This function returns true if the nnet has the following properties: It has an output called "output"...
Definition: nnet-utils.cc:52
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void ReadFstKaldi(std::istream &is, bool binary, VectorFst< Arc > *fst)
void TopSortCompactLatticeIfNeeded(CompactLattice *clat)
Topologically sort the compact lattice if not already topologically sorted.
#define KALDI_LOG
Definition: kaldi-error.h:153
This class wraps a ConstArpaLm format language model with the interface defined in DeterministicOnDem...