transform-feats.cc File Reference
Include dependency graph for transform-feats.cc:

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 26 of file transform-feats.cc.

References SpMatrix< Real >::AddMat2(), MatrixBase< Real >::AddMatMat(), kaldi::ClassifyRspecifier(), VectorBase< Real >::CopyColFromMat(), SequentialTableReader< Holder >::Done(), ParseOptions::GetArg(), RandomAccessTableReaderMapped< Holder >::HasKey(), KALDI_ERR, KALDI_LOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), kaldi::kNoRspecifier, kaldi::kNoTrans, kaldi::kTrans, SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), MatrixBase< Real >::NumCols(), MatrixBase< Real >::NumRows(), RandomAccessTableReaderMapped< Holder >::Open(), ParseOptions::PrintUsage(), ParseOptions::Read(), kaldi::ReadKaldiObject(), ParseOptions::Register(), SequentialTableReader< Holder >::Value(), RandomAccessTableReaderMapped< Holder >::Value(), and TableWriter< Holder >::Write().

26  {
27  try {
28  using namespace kaldi;
29 
30  const char *usage =
31  "Apply transform (e.g. LDA; HLDA; fMLLR/CMLLR; MLLT/STC)\n"
32  "Linear transform if transform-num-cols == feature-dim, affine if\n"
33  "transform-num-cols == feature-dim+1 (->append 1.0 to features)\n"
34  "Per-utterance by default, or per-speaker if utt2spk option provided\n"
35  "Global if transform-rxfilename provided.\n"
36  "Usage: transform-feats [options] (<transform-rspecifier>|<transform-rxfilename>) <feats-rspecifier> <feats-wspecifier>\n"
37  "See also: transform-vec, copy-feats, compose-transforms\n";
38 
39  ParseOptions po(usage);
40  std::string utt2spk_rspecifier;
41  po.Register("utt2spk", &utt2spk_rspecifier, "rspecifier for utterance to speaker map");
42 
43  po.Read(argc, argv);
44 
45  if (po.NumArgs() != 3) {
46  po.PrintUsage();
47  exit(1);
48  }
49 
50  std::string transform_rspecifier_or_rxfilename = po.GetArg(1);
51  std::string feat_rspecifier = po.GetArg(2);
52  std::string feat_wspecifier = po.GetArg(3);
53 
54  SequentialBaseFloatMatrixReader feat_reader(feat_rspecifier);
55  BaseFloatMatrixWriter feat_writer(feat_wspecifier);
56 
58  bool use_global_transform;
59  Matrix<BaseFloat> global_transform;
60  if (ClassifyRspecifier(transform_rspecifier_or_rxfilename, NULL, NULL)
61  == kNoRspecifier) {
62  // not an rspecifier -> interpret as rxfilename....
63  use_global_transform = true;
64  ReadKaldiObject(transform_rspecifier_or_rxfilename, &global_transform);
65  } else { // an rspecifier -> not a global transform.
66  use_global_transform = false;
67  if (!transform_reader.Open(transform_rspecifier_or_rxfilename,
68  utt2spk_rspecifier)) {
69  KALDI_ERR << "Problem opening transforms with rspecifier "
70  << '"' << transform_rspecifier_or_rxfilename << '"'
71  << " and utt2spk rspecifier "
72  << '"' << utt2spk_rspecifier << '"';
73  }
74  }
75 
76  enum { Unknown, Logdet, PseudoLogdet, DimIncrease };
77  int32 logdet_type = Unknown;
78  double tot_t = 0.0, tot_logdet = 0.0; // to compute average logdet weighted by time...
79  int32 num_done = 0, num_error = 0;
80  BaseFloat cached_logdet = -1;
81 
82  for (;!feat_reader.Done(); feat_reader.Next()) {
83  std::string utt = feat_reader.Key();
84  const Matrix<BaseFloat> &feat(feat_reader.Value());
85 
86  if (!use_global_transform && !transform_reader.HasKey(utt)) {
87  KALDI_WARN << "No fMLLR transform available for utterance "
88  << utt << ", producing no output for this utterance";
89  num_error++;
90  continue;
91  }
92  const Matrix<BaseFloat> &trans =
93  (use_global_transform ? global_transform : transform_reader.Value(utt));
94  int32 transform_rows = trans.NumRows(),
95  transform_cols = trans.NumCols(),
96  feat_dim = feat.NumCols();
97 
98  Matrix<BaseFloat> feat_out(feat.NumRows(), transform_rows);
99 
100  if (transform_cols == feat_dim) {
101  feat_out.AddMatMat(1.0, feat, kNoTrans, trans, kTrans, 0.0);
102  } else if (transform_cols == feat_dim + 1) {
103  // append the implicit 1.0 to the input features.
104  SubMatrix<BaseFloat> linear_part(trans, 0, transform_rows, 0, feat_dim);
105  feat_out.AddMatMat(1.0, feat, kNoTrans, linear_part, kTrans, 0.0);
106  Vector<BaseFloat> offset(transform_rows);
107  offset.CopyColFromMat(trans, feat_dim);
108  feat_out.AddVecToRows(1.0, offset);
109  } else {
110  KALDI_WARN << "Transform matrix for utterance " << utt << " has bad dimension "
111  << transform_rows << "x" << transform_cols << " versus feat dim "
112  << feat_dim;
113  if (transform_cols == feat_dim+2)
114  KALDI_WARN << "[perhaps the transform was created by compose-transforms, "
115  "and you forgot the --b-is-affine option?]";
116  num_error++;
117  continue;
118  }
119  num_done++;
120 
121  if (logdet_type == Unknown) {
122  if (transform_rows == feat_dim) logdet_type = Logdet; // actual logdet.
123  else if (transform_rows < feat_dim) logdet_type = PseudoLogdet; // see below
124  else logdet_type = DimIncrease; // makes no sense to have any logdet.
125  // PseudoLogdet is if we have a dimension-reducing transform T, we compute
126  // 1/2 logdet(T T^T). Why does this make sense? Imagine we do MLLT after
127  // LDA and compose the transforms; the MLLT matrix is A and the LDA matrix is L,
128  // so T = A L. T T^T = A L L^T A, so 1/2 logdet(T T^T) = logdet(A) + 1/2 logdet(L L^T).
129  // since L L^T is a constant, this is valid for comparing likelihoods if we're
130  // just trying to see if the MLLT is converging.
131  }
132 
133  if (logdet_type != DimIncrease) { // Accumulate log-determinant stats.
134  SubMatrix<BaseFloat> linear_transform(trans, 0, trans.NumRows(), 0, feat_dim);
135  // "linear_transform" is just the linear part of any transform, ignoring
136  // any affine (offset) component.
137  SpMatrix<BaseFloat> TT(trans.NumRows());
138  // TT = linear_transform * linear_transform^T
139  TT.AddMat2(1.0, linear_transform, kNoTrans, 0.0);
140  BaseFloat logdet;
141  if (use_global_transform) {
142  if (cached_logdet == -1)
143  cached_logdet = 0.5 * TT.LogDet(NULL);
144  logdet = cached_logdet;
145  } else {
146  logdet = 0.5 * TT.LogDet(NULL);
147  }
148  if (logdet != logdet || logdet-logdet != 0.0) // NaN or info.
149  KALDI_WARN << "Matrix has bad logdet " << logdet;
150  else {
151  tot_t += feat.NumRows();
152  tot_logdet += feat.NumRows() * logdet;
153  }
154  }
155  feat_writer.Write(utt, feat_out);
156  }
157  if (logdet_type != Unknown && logdet_type != DimIncrease)
158  KALDI_LOG << "Overall average " << (logdet_type == PseudoLogdet ? "[pseudo-]":"")
159  << "logdet is " << (tot_logdet/tot_t) << " over " << tot_t
160  << " frames.";
161  KALDI_LOG << "Applied transform to " << num_done << " utterances; " << num_error
162  << " had errors.";
163 
164  return (num_done != 0 ? 0 : 1);
165  } catch(const std::exception &e) {
166  std::cerr << e.what();
167  return -1;
168  }
169 }
void AddMat2(const Real alpha, const MatrixBase< Real > &M, MatrixTransposeType transM, const Real beta)
rank-N update: if (transM == kNoTrans) (*this) = beta*(*this) + alpha * M * M^T, or (if transM == kTrans) (*this) = beta*(*this) + alpha * M^T * M.
Definition: sp-matrix.cc:1110
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
Packed symmetric matrix class.
Definition: matrix-common.h:62
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
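This class is for when you are reading something in random access, but it may actually be stored per-speaker (or similar) while the keys you look up are per-utterance; Open() takes an optional utt2spk map to translate between them.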
Definition: kaldi-table.h:432
A templated class for writing objects to an archive or script file; see The Table concept.
Definition: kaldi-table.h:368
kaldi::int32 int32
RspecifierType ClassifyRspecifier(const std::string &rspecifier, std::string *rxfilename, RspecifierOptions *opts)
Definition: kaldi-table.cc:225
void ReadKaldiObject(const std::string &filename, Matrix< float > *m)
Definition: kaldi-io.cc:832
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more details.
Definition: parse-options.h:36
bool Open(const std::string &table_rxfilename, const std::string &utt2spk_rxfilename)
Note: when calling Open, utt2spk_rxfilename may be empty.
A templated class for reading objects sequentially from an archive or script file; see The Table concept.
Definition: kaldi-table.h:287
void AddMatMat(const Real alpha, const MatrixBase< Real > &A, MatrixTransposeType transA, const MatrixBase< Real > &B, MatrixTransposeType transB, const Real beta)
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
bool HasKey(const std::string &key)
A class representing a vector.
Definition: kaldi-vector.h:406
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
const T & Value(const std::string &key)
#define KALDI_LOG
Definition: kaldi-error.h:153
Sub-matrix representation.
Definition: kaldi-matrix.h:988