All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
gmm-basis-fmllr-accs.cc File Reference
#include <string>
#include <vector>
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "gmm/am-diag-gmm.h"
#include "hmm/transition-model.h"
#include "transform/fmllr-diag-gmm.h"
#include "transform/basis-fmllr-diag-gmm.h"
#include "hmm/posterior.h"
Include dependency graph for gmm-basis-fmllr-accs.cc:

Go to the source code of this file.

Namespaces

 kaldi
 Relabels neural network egs with the read pdf-id alignments.
 

Functions

void AccumulateForUtterance (const Matrix< BaseFloat > &feats, const Posterior &post, const TransitionModel &trans_model, const AmDiagGmm &am_gmm, FmllrDiagGmmAccs *spk_stats)
 
int main (int argc, char *argv[])
 

Function Documentation

int main ( int  argc,
char *  argv[] 
)

Definition at line 55 of file gmm-basis-fmllr-accs.cc.

References kaldi::AccumulateForUtterance(), AmDiagGmm::Dim(), SequentialTableReader< Holder >::Done(), ParseOptions::GetArg(), RandomAccessTableReader< Holder >::HasKey(), rnnlm::i, KALDI_LOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), MatrixBase< Real >::NumRows(), ParseOptions::PrintUsage(), AmDiagGmm::Read(), ParseOptions::Read(), TransitionModel::Read(), ParseOptions::Register(), Output::Stream(), Input::Stream(), RandomAccessTableReader< Holder >::Value(), and SequentialTableReader< Holder >::Value().

55  {
56  try {
57  typedef kaldi::int32 int32;
58  using namespace kaldi;
59  const char *usage =
60  "Accumulate gradient scatter from training set, either per utterance or \n"
61  "for the supplied set of speakers (spk2utt option). Reads posterior to accumulate \n"
62  "fMLLR stats for each speaker/utterance. Writes gradient scatter matrix.\n"
63  "Usage: gmm-basis-fmllr-accs [options] <model-in> <feature-rspecifier>"
64  "<post-rspecifier> <accs-wspecifier>\n";
65 
66  bool binary_write = true;
67  string spk2utt_rspecifier;
68  ParseOptions po(usage);
69  po.Register("binary", &binary_write, "Write output in binary mode");
70  po.Register("spk2utt", &spk2utt_rspecifier, "rspecifier for speaker to "
71  "utterance-list map");
72 
73  po.Read(argc, argv);
74  if (po.NumArgs() != 4) {
75  po.PrintUsage();
76  exit(1);
77  }
78 
79  string
80  model_rxfilename = po.GetArg(1),
81  feature_rspecifier = po.GetArg(2),
82  post_rspecifier = po.GetArg(3),
83  accs_wspecifier = po.GetArg(4);
84 
85  TransitionModel trans_model;
86  AmDiagGmm am_gmm;
87  {
88  bool binary;
89  Input ki(model_rxfilename, &binary);
90  trans_model.Read(ki.Stream(), binary);
91  am_gmm.Read(ki.Stream(), binary);
92  }
93 
94  RandomAccessPosteriorReader post_reader(post_rspecifier);
95  BasisFmllrAccus basis_accs(am_gmm.Dim());
96 
97  int32 num_done = 0, num_no_post = 0, num_other_error = 0;
98  if (spk2utt_rspecifier != "") { // per-speaker mode
99  SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
100  RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier);
101 
102  int32 num_spk = 0;
103  for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
104  FmllrDiagGmmAccs spk_stats(am_gmm.Dim());
105  string spk = spk2utt_reader.Key();
106  const vector<string> &uttlist = spk2utt_reader.Value();
107  for (size_t i = 0; i < uttlist.size(); i++) {
108  std::string utt = uttlist[i];
109  if (!feature_reader.HasKey(utt)) {
110  KALDI_WARN << "Did not find features for utterance " << utt;
111  num_other_error++;
112  continue;
113  }
114  if (!post_reader.HasKey(utt)) {
115  KALDI_WARN << "Did not find posteriors for utterance " << utt;
116  num_no_post++;
117  continue;
118  }
119  const Matrix<BaseFloat> &feats = feature_reader.Value(utt);
120  const Posterior &post = post_reader.Value(utt);
121  if (static_cast<int32>(post.size()) != feats.NumRows()) {
122  KALDI_WARN << "Posterior vector has wrong size " << (post.size())
123  << " vs. " << (feats.NumRows());
124  num_other_error++;
125  continue;
126  }
127 
128  AccumulateForUtterance(feats, post, trans_model, am_gmm, &spk_stats);
129 
130  num_done++;
131  } // end looping over all utterances of this speaker
132  basis_accs.AccuGradientScatter(spk_stats);
133  num_spk++;
134  } // end looping over speakers
135  KALDI_LOG << "Accumulate statistics from " << num_spk << " speakers";
136 
137  } else { // per-utterance mode
138  SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
139  for (; !feature_reader.Done(); feature_reader.Next()) {
140  string utt = feature_reader.Key();
141  if (!post_reader.HasKey(utt)) {
142  KALDI_WARN << "Did not find posts for utterance "
143  << utt;
144  num_no_post++;
145  continue;
146  }
147  const Matrix<BaseFloat> &feats = feature_reader.Value();
148  const Posterior &post = post_reader.Value(utt);
149 
150  if (static_cast<int32>(post.size()) != feats.NumRows()) {
151  KALDI_WARN << "Posterior has wrong size " << (post.size())
152  << " vs. " << (feats.NumRows());
153  num_other_error++;
154  continue;
155  }
156  // Accumulate stats for this utterance
157  FmllrDiagGmmAccs utt_stats(am_gmm.Dim());
158  AccumulateForUtterance(feats, post, trans_model, am_gmm, &utt_stats);
159  num_done++;
160 
161  basis_accs.AccuGradientScatter(utt_stats);
162  } // end looping over utterances
163  }
164  // Write out accumulations
165  {
166  Output ko(accs_wspecifier, binary_write);
167  basis_accs.Write(ko.Stream(), binary_write);
168  }
169  KALDI_LOG << "Done " << num_done << " files, " << num_no_post
170  << " with no posts, " << num_other_error << " with other errors.";
171  KALDI_LOG << "Written gradient scatter to " << accs_wspecifier;
172  return (num_done != 0 ? 0 : 1);
173  } catch(const std::exception& e) {
174  std::cerr << e.what();
175  return -1;
176  }
177 }
Relabels neural network egs with the read pdf-id alignments.
Definition: chain.dox:20
This does not work with multiple feature transforms.
void AccumulateForUtterance(const Matrix< BaseFloat > &feats, const GaussPost &gpost, const TransitionModel &trans_model, const AmDiagGmm &am_gmm, FmllrDiagGmmAccs *spk_stats)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
Definition: posterior.h:43
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
Stats for fMLLR subspace estimation.
int32 Dim() const
Definition: am-diag-gmm.h:79
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
#define KALDI_WARN
Definition: kaldi-error.h:130
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:58
#define KALDI_LOG
Definition: kaldi-error.h:133
void Read(std::istream &in_stream, bool binary)
Definition: am-diag-gmm.cc:147