lattice-combine.cc
Go to the documentation of this file.
1 // latbin/lattice-combine.cc
2 
3 // Copyright 2012 Arnab Ghoshal
4 // 2016 Johns Hopkins University (author: Daniel Povey)
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 // This program is for system combination using MBR decoding as described in:
22 // "Minimum Bayes Risk decoding and system combination based on a recursion for
23 // edit distance", Haihua Xu, Daniel Povey, Lidia Mangu and Jie Zhu, Computer
24 // Speech and Language, 2011. However, instead of averaging the posteriors, as
25 // described in the paper, this removes the total backward probability from the
26 // individual lattices being combined and outputs the union of them. The output
27 // should be used with lattice-mbr-decode (without any acoustic or LM scaling)
28 // or with lattice-to-ctm-conf with --decode-mbr=true (also without any scaling)
29 
30 // IMPORTANT CAVEAT: the total backward probability (which is a float) is
31 // removed from value1_ of arc weight. So graph scores are no longer correct
32 // but instead only the combined acoustic and graph scores are valid. So no
33 // acoustic or LM scaling should be done with the output of this program.
34 
35 #include <string>
36 using std::string;
37 #include <vector>
38 using std::vector;
39 
40 #include "util/common-utils.h"
41 #include "lat/lattice-functions.h"
42 #include "lat/kaldi-lattice.h"
43 #include "lat/sausages.h"
44 
45 namespace kaldi {
46 
47 // This removes the total weight from a CompactLattice. Since the total backward
48 // score is in log likelihood domain, and the lattice weights are in negative
49 // log likelihood domain, the total weight is *added* to the weight of the final
50 // states. This is equivalent to dividing the probability of each path by the
51 // total probability over all paths. There is an additional weight to control
52 // the relative contribution of individual lattices-- the log of the weight will
53 // become the total weight of the lattice.
55  if (weight <= 0.0) {
56  KALDI_WARN << "Weights must be positive; found: " << weight;
57  return false;
58  }
59 
60  if (clat->Properties(fst::kTopSorted, false) == 0) {
61  if (fst::TopSort(clat) == false) {
62  KALDI_WARN << "Cycles detected in lattice: cannot normalize.";
63  return false;
64  }
65  }
66 
67  vector<double> beta;
68  if (!ComputeCompactLatticeBetas(*clat, &beta)) {
69  KALDI_WARN << "Failed to compute backward probabilities on lattice.";
70  return false;
71  }
72 
74  StateId start = clat->Start(); // Should be 0
75  BaseFloat total_backward_cost = beta[start];
76 
77  total_backward_cost -= Log(weight);
78 
79  for (fst::StateIterator<CompactLattice> sit(*clat); !sit.Done(); sit.Next()) {
80  CompactLatticeWeight f = clat->Final(sit.Value());
81  LatticeWeight w = f.Weight();
82  w.SetValue1(w.Value1() + total_backward_cost);
83  f.SetWeight(w);
84  clat->SetFinal(sit.Value(), f);
85  }
86  return true;
87 }
88 
89 // This is a wrapper for SplitStringToFloats, with added checks to make sure
90 // the weights are valid probabilities.
91 void SplitStringToWeights(const string &full, const char *delim,
92  vector<BaseFloat> *out) {
93  vector<BaseFloat> tmp;
94  SplitStringToFloats(full, delim, true /*omit empty strings*/, &tmp);
95  if (tmp.size() != out->size()) {
96  KALDI_WARN << "Expecting " << out->size() << " weights, found " << tmp.size()
97  << ": using uniform weights.";
98  return;
99  }
100  BaseFloat sum = 0;
101  for (vector<BaseFloat>::const_iterator itr = tmp.begin();
102  itr != tmp.end(); ++itr) {
103  if (*itr < 0.0) {
104  KALDI_WARN << "Cannot use negative weight: " << *itr << "; input string: "
105  << full << "\n\tUsing uniform weights.";
106  return;
107  }
108  sum += (*itr);
109  }
110  if (sum != 1.0) {
111  KALDI_WARN << "Weights sum to " << sum << " instead of 1: renormalizing";
112  for (vector<BaseFloat>::iterator itr = tmp.begin();
113  itr != tmp.end(); ++itr)
114  (*itr) /= sum;
115  }
116  out->swap(tmp);
117 }
118 
119 } // end namespace kaldi
120 
121 
122 int main(int argc, char *argv[]) {
123  try {
124  using namespace kaldi;
125  typedef kaldi::int32 int32;
126 
127  const char *usage =
128  "Combine lattices generated by different systems by removing the total\n"
129  "cost of all paths (backward cost) from individual lattices and doing\n"
130  "a union of the reweighted lattices. Note: the acoustic and LM scales\n"
131  "that this program applies are not removed before outputting the lattices.\n"
132  "Intended for use in system combination prior to MBR decoding, see comments\n"
133  "in code.\n"
134  "Usage: lattice-combine [options] <lattice-rspecifier1> <lattice-rspecifier2>"
135  " [<lattice-rspecifier3> ... ] <lattice-wspecifier>\n"
136  "E.g.: lattice-combine 'ark:gunzip -c foo/lat.1.gz|' 'ark:gunzip -c bar/lat.1.gz|' ark:- | ...\n";
137 
138  ParseOptions po(usage);
139  BaseFloat acoustic_scale = 1.0, inv_acoustic_scale = 1.0, lm_scale = 1.0;
140  string weight_str;
141  po.Register("acoustic-scale", &acoustic_scale, "Scaling factor for "
142  "acoustic likelihoods");
143  po.Register("inv-acoustic-scale", &inv_acoustic_scale, "An alternative way "
144  "of setting the acoustic scale: you can set its inverse.");
145  po.Register("lm-scale", &lm_scale, "Scaling factor for language model "
146  "probabilities");
147  po.Register("lat-weights", &weight_str, "Colon-separated list of weights "
148  "for each rspecifier (which should sum to 1), e.g. '0.2:0.8'");
149 
150  po.Read(argc, argv);
151 
152  KALDI_ASSERT(acoustic_scale == 1.0 || inv_acoustic_scale == 1.0);
153  if (inv_acoustic_scale != 1.0)
154  acoustic_scale = 1.0 / inv_acoustic_scale;
155 
156 
157  int32 num_args = po.NumArgs();
158  if (num_args < 3) {
159  po.PrintUsage();
160  exit(1);
161  }
162 
163  string lats_rspecifier1 = po.GetArg(1),
164  lats_wspecifier = po.GetArg(num_args);
165 
166  // Output lattice
167  CompactLatticeWriter clat_writer(lats_wspecifier);
168 
169  // Input lattices
170  SequentialCompactLatticeReader clat_reader1(lats_rspecifier1);
171  vector<RandomAccessCompactLatticeReader*> clat_reader_vec(
172  num_args-2, static_cast<RandomAccessCompactLatticeReader*>(NULL));
173  vector<string> clat_rspec_vec(num_args-2);
174  for (int32 i = 2; i < num_args; ++i) {
175  clat_reader_vec[i-2] = new RandomAccessCompactLatticeReader(po.GetArg(i));
176  clat_rspec_vec[i-2] = po.GetArg(i);
177  }
178 
179  vector<BaseFloat> lat_weights(num_args-1, 1.0/(num_args-1));
180  if (!weight_str.empty())
181  SplitStringToWeights(weight_str, ":", &lat_weights);
182 
183  int32 n_utts = 0, n_total_lats = 0, n_success = 0, n_missing = 0,
184  n_other_errors = 0;
185  vector< vector<double> > lat_scale = fst::LatticeScale(lm_scale,
186  acoustic_scale);
187 
188  for (; !clat_reader1.Done(); clat_reader1.Next()) {
189  std::string key = clat_reader1.Key();
190  CompactLattice clat1 = clat_reader1.Value();
191  clat_reader1.FreeCurrent();
192  n_utts++;
193  n_total_lats++;
194  fst::ScaleLattice(lat_scale, &clat1);
195  bool success = CompactLatticeNormalize(&clat1, lat_weights[0]);
196  if (!success) {
197  KALDI_WARN << "Could not normalize lattice for system 1, utterance: "
198  << key;
199  n_other_errors++;
200  continue;
201  }
202 
203  for (int32 i = 0; i < num_args-2; ++i) {
204  if (clat_reader_vec[i]->HasKey(key)) {
205  CompactLattice clat2 = clat_reader_vec[i]->Value(key);
206  n_total_lats++;
207  fst::ScaleLattice(lat_scale, &clat2);
208  success = CompactLatticeNormalize(&clat2, lat_weights[i+1]);
209  if (!success) {
210  KALDI_WARN << "Could not normalize lattice for system "<< (i + 2)
211  << ", utterance: " << key;
212  n_other_errors++;
213  continue;
214  }
215  fst::Union(&clat1, clat2);
216  } else {
217  KALDI_WARN << "No lattice found for utterance " << key << " for "
218  << "system " << (i + 2) << ", rspecifier: "
219  << clat_rspec_vec[i];
220  n_missing++;
221  }
222  }
223 
224  clat_writer.Write(key, clat1);
225  n_success++;
226  }
227 
228  KALDI_LOG << "Processed " << n_utts << " utterances: with a total of "
229  << n_total_lats << " lattices across " << (num_args-1)
230  << " different systems";
231  KALDI_LOG << "Produced output for " << n_success << " utterances; "
232  << n_missing << " total missing lattices and " << n_other_errors
233  << " total lattices had errors in processing.";
234  DeletePointers(&clat_reader_vec);
235  // The success code we choose is that at least one lattice was output,
236  // and more lattices were "all there" than had at least one system missing.
237  return (n_success != 0 && n_missing < (n_success - n_missing) ? 0 : 1);
238  } catch(const std::exception &e) {
239  std::cerr << e.what();
240  return -1;
241  }
242 }
fst::StdArc::StateId StateId
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void DeletePointers(std::vector< A *> *v)
Deletes any non-NULL pointers in the vector v, and sets the corresponding entries of v to NULL...
Definition: stl-utils.h:184
bool SplitStringToFloats(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< F > *out)
Definition: text-utils.cc:30
Lattice::StateId StateId
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
float BaseFloat
Definition: kaldi-types.h:29
int main(int argc, char *argv[])
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
double Log(double x)
Definition: kaldi-math.h:100
void ScaleLattice(const std::vector< std::vector< ScaleFloat > > &scale, MutableFst< ArcTpl< Weight > > *fst)
Scales the pairs of weights in LatticeWeight or CompactLatticeWeight by viewing the pair (a...
bool ComputeCompactLatticeBetas(const CompactLattice &clat, vector< double > *beta)
bool CompactLatticeNormalize(CompactLattice *clat, BaseFloat weight)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
std::vector< std::vector< double > > LatticeScale(double lmwt, double acwt)
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
#define KALDI_WARN
Definition: kaldi-error.h:150
fst::VectorFst< CompactLatticeArc > CompactLattice
Definition: kaldi-lattice.h:46
RandomAccessTableReader< CompactLatticeHolder > RandomAccessCompactLatticeReader
int NumArgs() const
Number of positional parameters (c.f. argc-1).
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void SplitStringToWeights(const string &full, const char *delim, vector< BaseFloat > *out)
#define KALDI_LOG
Definition: kaldi-error.h:153