lattice-lmrescore-kaldi-rnnlm.cc
Go to the documentation of this file.
1 // latbin/lattice-lmrescore-kaldi-rnnlm.cc
2 
3 // Copyright 2017 Johns Hopkins University (author: Daniel Povey)
4 // 2017 Hainan Xu
5 // 2017 Yiming Wang
6 
7 // See ../../COPYING for clarification regarding multiple authors
8 //
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 //
13 // http://www.apache.org/licenses/LICENSE-2.0
14 //
15 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
17 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
18 // MERCHANTABLITY OR NON-INFRINGEMENT.
19 // See the Apache 2 License for the specific language governing permissions and
20 // limitations under the License.
21 
22 
23 #include "base/kaldi-common.h"
24 #include "fstext/fstext-lib.h"
25 #include "lat/kaldi-lattice.h"
26 #include "lat/lattice-functions.h"
27 #include "rnnlm/rnnlm-lattice-rescoring.h"
28 #include "util/common-utils.h"
29 #include "nnet3/nnet-utils.h"
30 
31 int main(int argc, char *argv[]) {
32  try {
33  using namespace kaldi;
34  typedef kaldi::int32 int32;
35  typedef kaldi::int64 int64;
36 
37  const char *usage =
38  "Rescores lattice with kaldi-rnnlm. This script is called from \n"
39  "scripts/rnnlm/lmrescore.sh. An example for rescoring \n"
40  "lattices is at egs/swbd/s5c/local/rnnlm/run_lstm.sh \n"
41  "\n"
42  "Usage: lattice-lmrescore-kaldi-rnnlm [options] \\\n"
43  " <embedding-file> <raw-rnnlm-rxfilename> \\\n"
44  " <lattice-rspecifier> <lattice-wspecifier>\n"
45  " e.g.: lattice-lmrescore-kaldi-rnnlm --lm-scale=-1.0 \\\n"
46  " word_embedding.mat \\\n"
47  " --bos-symbol=1 --eos-symbol=2 \\\n"
48  " final.raw ark:in.lats ark:out.lats\n";
49 
50  ParseOptions po(usage);
51  rnnlm::RnnlmComputeStateComputationOptions opts;
52 
53  int32 max_ngram_order = 3;
54  BaseFloat lm_scale = 1.0;
55 
56  po.Register("lm-scale", &lm_scale, "Scaling factor for language model "
57  "costs");
58  po.Register("max-ngram-order", &max_ngram_order,
59  "If positive, allow RNNLM histories longer than this to be identified "
60  "with each other for rescoring purposes (an approximation that "
61  "saves time and reduces output lattice size).");
62  opts.Register(&po);
63 
64  po.Read(argc, argv);
65 
66  if (po.NumArgs() != 4) {
67  po.PrintUsage();
68  exit(1);
69  }
70 
71  if (opts.bos_index == -1 || opts.eos_index == -1) {
72  KALDI_ERR << "You must set --bos-symbol and --eos-symbol options";
73  }
74 
75  std::string word_embedding_rxfilename = po.GetArg(1),
76  rnnlm_rxfilename = po.GetArg(2),
77  lats_rspecifier = po.GetArg(3),
78  lats_wspecifier = po.GetArg(4);
79 
81  ReadKaldiObject(rnnlm_rxfilename, &rnnlm);
82 
83  KALDI_ASSERT(IsSimpleNnet(rnnlm));
84 
85  CuMatrix<BaseFloat> word_embedding_mat;
86  ReadKaldiObject(word_embedding_rxfilename, &word_embedding_mat);
87 
88  const rnnlm::RnnlmComputeStateInfo info(opts, rnnlm, word_embedding_mat);
89 
90  // Reads and writes as compact lattice.
91  SequentialCompactLatticeReader compact_lattice_reader(lats_rspecifier);
92  CompactLatticeWriter compact_lattice_writer(lats_wspecifier);
93 
94  int32 n_done = 0, n_fail = 0;
95 
96  rnnlm::KaldiRnnlmDeterministicFst rnnlm_fst(max_ngram_order, info);
97 
98  for (; !compact_lattice_reader.Done(); compact_lattice_reader.Next()) {
99  std::string key = compact_lattice_reader.Key();
100  CompactLattice &clat = compact_lattice_reader.Value();
101 
102  if (lm_scale != 0.0) {
103  // Before composing with the LM FST, we scale the lattice weights
104  // by the inverse of "lm_scale". We'll later scale by "lm_scale".
105  // We do it this way so we can determinize and it will give the
106  // right effect (taking the "best path" through the LM) regardless
107  // of the sign of lm_scale.
108  fst::ScaleLattice(fst::GraphLatticeScale(1.0 / lm_scale), &clat);
109  ArcSort(&clat, fst::OLabelCompare<CompactLatticeArc>());
110 
111  // Wraps the rnnlm into FST. We re-create it for each lattice to prevent
112  // memory usage increasing with time.
113 
114  // Composes lattice with language model.
115  CompactLattice composed_clat;
116  ComposeCompactLatticeDeterministic(clat, &rnnlm_fst, &composed_clat);
117 
118  // Determinizes the composed lattice.
119  Lattice composed_lat;
120  ConvertLattice(composed_clat, &composed_lat);
121  Invert(&composed_lat);
122  CompactLattice determinized_clat;
123  DeterminizeLattice(composed_lat, &determinized_clat);
124  fst::ScaleLattice(fst::GraphLatticeScale(lm_scale), &determinized_clat);
125  if (determinized_clat.Start() == fst::kNoStateId) {
126  KALDI_WARN << "Empty lattice for utterance " << key
127  << " (incompatible LM?)";
128  n_fail++;
129  } else {
130  compact_lattice_writer.Write(key, determinized_clat);
131  n_done++;
132  }
133  } else {
134  // Zero scale so nothing to do.
135  n_done++;
136  compact_lattice_writer.Write(key, clat);
137  }
138  rnnlm_fst.Clear();
139  }
140 
141  KALDI_LOG << "Done " << n_done << " lattices, failed for " << n_fail;
142  return (n_done != 0 ? 0 : 1);
143  } catch(const std::exception &e) {
144  std::cerr << e.what();
145  return -1;
146  }
147 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
Definition: matrix-common.h:71
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
This file contains some miscellaneous functions dealing with class Nnet.
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void ScaleLattice(const std::vector< std::vector< ScaleFloat > > &scale, MutableFst< ArcTpl< Weight > > *fst)
Scales the pairs of weights in LatticeWeight or CompactLatticeWeight by viewing the pair (a...
void ConvertLattice(const ExpandedFst< ArcTpl< Weight > > &ifst, MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, Int > > > *ofst, bool invert)
Convert lattice from a normal FST to a CompactLattice FST.
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
void ComposeCompactLatticeDeterministic(const CompactLattice &clat, fst::DeterministicOnDemandFst< fst::StdArc > *det_fst, CompactLattice *composed_clat)
This function Composes a CompactLattice format lattice with a DeterministicOnDemandFst<fst::StdFst> f...
fst::VectorFst< LatticeArc > Lattice
Definition: kaldi-lattice.h:44
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
fst::VectorFst< CompactLatticeArc > CompactLattice
Definition: kaldi-lattice.h:46
std::vector< std::vector< double > > GraphLatticeScale(double lmwt)
int NumArgs() const
Number of positional parameters (c.f. argc-1).
bool IsSimpleNnet(const Nnet &nnet)
This function returns true if the nnet has the following properties: It has an output called "output"...
Definition: nnet-utils.cc:52
int main(int argc, char *argv[])
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
bool DeterminizeLattice(const Fst< ArcTpl< Weight > > &ifst, MutableFst< ArcTpl< Weight > > *ofst, DeterminizeLatticeOptions opts, bool *debug_ptr)
This function implements the normal version of DeterminizeLattice, in which the output strings are re...
#define KALDI_LOG
Definition: kaldi-error.h:153