nnet-compare-hash-discriminative.cc File Reference
Include dependency graph for nnet-compare-hash-discriminative.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main(int argc, char *argv[])

Definition at line 25 of file nnet-compare-hash-discriminative.cc.

References kaldi::ApproxEqual(), DiscriminativeNnetExample::den_lat, SequentialTableReader< Holder >::Done(), ParseOptions::GetArg(), KALDI_ASSERT, KALDI_LOG, kaldi::kTrans, fst::LatticeScale(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), ParseOptions::PrintUsage(), ParseOptions::Read(), kaldi::ReadKaldiObject(), ParseOptions::Register(), fst::ScaleLattice(), kaldi::TraceMatMat(), kaldi::nnet2::UpdateHash(), and SequentialTableReader< Holder >::Value().

25  {
26  try {
27  using namespace kaldi;
28  using namespace kaldi::nnet2;
29  typedef kaldi::int32 int32;
30  typedef kaldi::int64 int64;
31 
32  const char *usage =
33  "Compares two archives of discriminative training examples and checks\n"
34  "that they behave the same way for purposes of discriminative training.\n"
35  "This program was created as a way of testing nnet-get-egs-discriminative\n"
36  "The model is only needed for its transition-model.\n"
37  "\n"
38  "Usage: nnet-compare-hash-discriminative [options] <model-rxfilename> "
39  "<egs-rspecifier1> <egs-rspecifier2>\n"
40  "\n"
41  "Note: options --drop-frames and --criterion should be matched with the\n"
42  "command line of nnet-get-egs-discriminative used to get the examples\n"
43  "nnet-compare-hash-discriminative --drop-frames=true --criterion=mmi ark:1.degs ark:2.degs\n";
44 
45  std::string criterion = "smbr";
46  bool drop_frames = false;
47  bool one_silence_class = false;
48  BaseFloat threshold = 0.002;
49  BaseFloat acoustic_scale = 1.0, lm_scale = 1.0;
50  ParseOptions po(usage);
51 
52  po.Register("acoustic-scale", &acoustic_scale,
53  "Scaling factor for acoustic likelihoods");
54  po.Register("lm-scale", &lm_scale,
55  "Scaling factor for \"graph costs\" (including LM costs)");
56  po.Register("criterion", &criterion, "Training criterion, 'mmi'|'mpfe'|'smbr'");
57  po.Register("drop-frames", &drop_frames, "If true, for MMI training, drop "
58  "frames where num and den do not intersect.");
59  po.Register("one-silence-class", &one_silence_class, "If true, newer "
60  "behavior which will tend to reduce insertions.");
61  po.Register("threshold", &threshold, "Threshold for equality testing "
62  "(relative)");
63 
64  po.Read(argc, argv);
65 
66 
67  if (po.NumArgs() != 3) {
68  po.PrintUsage();
69  exit(1);
70  }
71 
72  std::string model_rxfilename = po.GetArg(1),
73  examples_rspecifier1 = po.GetArg(2),
74  examples_rspecifier2 = po.GetArg(3);
75 
76  int64 num_done1 = 0, num_done2 = 0;
77 
78 
79  TransitionModel tmodel;
80  ReadKaldiObject(model_rxfilename, &tmodel);
81 
82  Matrix<double> hash1, hash2;
83 
84  // some additional diagnostics:
85  double num_weight1 = 0.0, den_weight1 = 0.0, tot_t1 = 0.0;
86  double num_weight2 = 0.0, den_weight2 = 0.0, tot_t2 = 0.0;
87 
89  example_reader1(examples_rspecifier1),
90  example_reader2(examples_rspecifier2);
91 
92  KALDI_LOG << "Computing first hash function";
93  for (; !example_reader1.Done(); example_reader1.Next(), num_done1++) {
94  DiscriminativeNnetExample eg = example_reader1.Value();
95  fst::ScaleLattice(fst::LatticeScale(lm_scale, acoustic_scale),
96  &(eg.den_lat));
97  UpdateHash(tmodel, eg, criterion, drop_frames,
98  one_silence_class, &hash1,
99  &num_weight1, &den_weight1, &tot_t1);
100  }
101  KALDI_LOG << "Processed " << num_done1 << " examples.";
102 
103  KALDI_LOG << "Computing second hash function";
104  for (; !example_reader2.Done(); example_reader2.Next(), num_done2++) {
105  DiscriminativeNnetExample eg = example_reader2.Value();
106  fst::ScaleLattice(fst::LatticeScale(lm_scale, acoustic_scale),
107  &(eg.den_lat));
108  UpdateHash(tmodel, eg, criterion, drop_frames,
109  one_silence_class, &hash2,
110  &num_weight2, &den_weight2, &tot_t2);
111  }
112  KALDI_LOG << "Processed " << num_done2 << " examples.";
113 
114  double prod1 = TraceMatMat(hash1, hash1, kTrans),
115  prod2 = TraceMatMat(hash2, hash2, kTrans),
116  cross_prod = TraceMatMat(hash1, hash2, kTrans);
117 
118  KALDI_LOG << "Products are as follows (should be the same): prod1 = "
119  << prod1 << ", prod2 = " << prod2 << ", cross_prod = "
120  << cross_prod;
121 
122  KALDI_LOG << "Num-weight1 = " << num_weight1 << ", den-weight1 = "
123  << den_weight1 << ", tot_t1 = " << tot_t1;
124  KALDI_LOG << "Num-weight2 = " << num_weight2 << ", den-weight2 = "
125  << den_weight2 << ", tot_t2 = " << tot_t2;
126 
127  KALDI_ASSERT(ApproxEqual(prod1, prod2, threshold) &&
128  ApproxEqual(prod2, cross_prod, threshold));
129  KALDI_ASSERT(prod1 > 0.0);
130 
131  return 0;
132  } catch(const std::exception &e) {
133  std::cerr << e.what() << '\n';
134  return -1;
135  }
136 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void UpdateHash(const TransitionModel &tmodel, const DiscriminativeNnetExample &eg, std::string criterion, bool drop_frames, bool one_silence_class, Matrix< double > *hash, double *num_weight, double *den_weight, double *tot_t)
This function is used in code that tests the functionality that we provide here, about splitting and ...
kaldi::int32 int32
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void ScaleLattice(const std::vector< std::vector< ScaleFloat > > &scale, MutableFst< ArcTpl< Weight > > *fst)
Scales the pairs of weights in LatticeWeight or CompactLatticeWeight by viewing the pair (a...
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
std::vector< std::vector< double > > LatticeScale(double lmwt, double acwt)
CompactLattice den_lat
The denominator lattice.
Definition: nnet-example.h:148
Real TraceMatMat(const MatrixBase< Real > &A, const MatrixBase< Real > &B, MatrixTransposeType trans)
We need to declare this here as it will be a friend function.
This struct is used to store the information we need for discriminative training (MMI or MPE)...
Definition: nnet-example.h:136
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
#define KALDI_LOG
Definition: kaldi-error.h:153
static bool ApproxEqual(float a, float b, float relative_tolerance=0.001)
Returns true if abs(a - b) <= relative_tolerance * (abs(a) + abs(b)).
Definition: kaldi-math.h:265