compute-cmvn-stats-two-channel.cc
Go to the documentation of this file.
1 // featbin/compute-cmvn-stats-two-channel.cc
2 
3 // Copyright 2013 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABILITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include "base/kaldi-common.h"
21 #include "util/common-utils.h"
22 #include "matrix/kaldi-matrix.h"
23 #include "transform/cmvn.h"
24 
25 namespace kaldi {
26 
27 
28 /*
29  This function gets the utterances that are the first field of the
30  contents of the file reco2file_and_channel_rxfilename, and sorts
31  them into pairs corresponding to A/B sides, or singletons in case
32  we get one without the other.
33  */
34 void GetUtterancePairs(const std::string &reco2file_and_channel_rxfilename,
35  std::vector<std::vector<std::string> > *utt_pairs) {
36  Input ki(reco2file_and_channel_rxfilename);
37  std::string line;
38  std::map<std::string, std::vector<std::string> > call_to_uttlist;
39  while (std::getline(ki.Stream(), line)) {
40  std::vector<std::string> split_line;
41  SplitStringToVector(line, " \t\r", true, &split_line);
42  if (split_line.size() != 3) {
43  KALDI_ERR << "Expecting 3 fields per line of reco2file_and_channel file "
44  << PrintableRxfilename(reco2file_and_channel_rxfilename)
45  << ", got: " << line;
46  }
47  // lines like: sw02001-A sw02001 A
48  std::string utt = split_line[0],
49  call = split_line[1];
50  call_to_uttlist[call].push_back(utt);
51  }
52  for (std::map<std::string, std::vector<std::string> >::const_iterator
53  iter = call_to_uttlist.begin(); iter != call_to_uttlist.end(); ++iter) {
54  const std::vector<std::string> &uttlist = iter->second;
55  if (uttlist.size() == 2) {
56  utt_pairs->push_back(uttlist);
57  } else {
58  KALDI_WARN << "Call " << iter->first << " has " << uttlist.size()
59  << " utterances, expected two; treating them singly.";
60  for (size_t i = 0; i < uttlist.size(); i++) {
61  std::vector<std::string> singleton_list;
62  singleton_list.push_back(uttlist[i]);
63  utt_pairs->push_back(singleton_list);
64  }
65  }
66  }
67 }
68 
69 void AccCmvnStatsForPair(const std::string &utt1, const std::string &utt2,
70  const MatrixBase<BaseFloat> &feats1,
71  const MatrixBase<BaseFloat> &feats2,
72  BaseFloat quieter_channel_weight,
73  MatrixBase<double> *cmvn_stats1,
74  MatrixBase<double> *cmvn_stats2) {
75  KALDI_ASSERT(feats1.NumCols() == feats2.NumCols()); // same dim.
76  if (feats1.NumRows() != feats2.NumRows()) {
77  KALDI_WARN << "Number of frames differ between " << utt1 << " and " << utt2
78  << ": " << feats1.NumRows() << " vs. " << feats2.NumRows()
79  << ", treating them separately.";
80  AccCmvnStats(feats1, NULL, cmvn_stats1);
81  AccCmvnStats(feats2, NULL, cmvn_stats2);
82  return;
83  }
84 
85  for (int32 i = 0; i < feats1.NumRows(); i++) {
86  if (feats1(i, 0) > feats2(i, 0)) {
87  AccCmvnStats(feats1.Row(i), 1.0, cmvn_stats1);
88  AccCmvnStats(feats2.Row(i), quieter_channel_weight, cmvn_stats2);
89  }
90  else {
91  AccCmvnStats(feats2.Row(i), 1.0, cmvn_stats2);
92  AccCmvnStats(feats1.Row(i), quieter_channel_weight, cmvn_stats1);
93  }
94  }
95 }
96 
97 
98 }
99 
100 int main(int argc, char *argv[]) {
101  try {
102  using namespace kaldi;
103  using kaldi::int32;
104 
105  const char *usage =
106  "Compute cepstral mean and variance normalization statistics\n"
107  "Specialized for two-sided telephone data where we only accumulate\n"
108  "the louder of the two channels at each frame (and add it to that\n"
109  "side's stats). Reads a 'reco2file_and_channel' file, normally like\n"
110  "sw02001-A sw02001 A\n"
111  "sw02001-B sw02001 B\n"
112  "sw02005-A sw02005 A\n"
113  "sw02005-B sw02005 B\n"
114  "interpreted as <utterance-id> <call-id> <side> and for each <call-id>\n"
115  "that has two sides, does the 'only-the-louder' computation, else doesn\n"
116  "per-utterance stats in the normal way.\n"
117  "Note: loudness is judged by the first feature component, either energy or c0;\n"
118  "only applicable to MFCCs or PLPs (this code could be modified to handle filterbanks).\n"
119  "\n"
120  "Usage: compute-cmvn-stats-two-channel [options] <reco2file-and-channel> <feats-rspecifier> <stats-wspecifier>\n"
121  "e.g.: compute-cmvn-stats-two-channel data/train_unseg/reco2file_and_channel scp:data/train_unseg/feats.scp ark,t:-\n";
122 
123 
124  ParseOptions po(usage);
125  BaseFloat quieter_channel_weight = 0.01;
126 
127  po.Register("quieter-channel-weight", &quieter_channel_weight,
128  "For the quieter channel, apply this weight to the stats, so "
129  "that we still get stats if one channel always dominates.");
130 
131  po.Read(argc, argv);
132 
133  if (po.NumArgs() != 3) {
134  po.PrintUsage();
135  exit(1);
136  }
137 
138  int32 num_done = 0, num_err = 0;
139 
140  std::string reco2file_and_channel_rxfilename = po.GetArg(1),
141  feats_rspecifier = po.GetArg(2),
142  stats_wspecifier = po.GetArg(3);
143 
144 
145  std::vector<std::vector<std::string> > utt_pairs;
146  GetUtterancePairs(reco2file_and_channel_rxfilename, &utt_pairs);
147 
148  RandomAccessBaseFloatMatrixReader feat_reader(feats_rspecifier);
149  DoubleMatrixWriter writer(stats_wspecifier);
150 
151  for (size_t i = 0; i < utt_pairs.size(); i++) {
152  std::vector<std::string> this_pair(utt_pairs[i]);
153 
154  KALDI_ASSERT(this_pair.size() == 2 || this_pair.size() == 1);
155  if (this_pair.size() == 2) {
156  std::string utt1 = this_pair[0], utt2 = this_pair[1];
157  if (!feat_reader.HasKey(utt1)) {
158  KALDI_WARN << "No feature data for utterance " << utt1;
159  num_err++;
160  this_pair[0] = utt2;
161  this_pair.pop_back();
162  // and fall through to the singleton code below.
163  } else if (!feat_reader.HasKey(utt2)) {
164  KALDI_WARN << "No feature data for utterance " << utt2;
165  num_err++;
166  this_pair.pop_back();
167  // and fall through to the singleton code below.
168  } else {
169  Matrix<BaseFloat> feats1 = feat_reader.Value(utt1),
170  feats2 = feat_reader.Value(utt2);
171  int32 dim = feats1.NumCols();
172  Matrix<double> cmvn_stats1(2, dim + 1), cmvn_stats2(2, dim + 1);
173  AccCmvnStatsForPair(utt1, utt2, feats1, feats2, quieter_channel_weight,
174  &cmvn_stats1, &cmvn_stats2);
175  writer.Write(utt1, cmvn_stats1);
176  writer.Write(utt2, cmvn_stats2);
177  num_done += 2;
178  continue; // continue so we don't go to the singleton-processing code
179  // below.
180  }
181  }
182  // process singletons.
183  std::string utt = this_pair[0];
184  if (!feat_reader.HasKey(utt)) {
185  KALDI_WARN << "No feature data for utterance " << utt;
186  num_err++;
187  continue;
188  }
189  const Matrix<BaseFloat> &feats = feat_reader.Value(utt);
190  Matrix<double> cmvn_stats(2, feats.NumCols() + 1);
191  AccCmvnStats(feats, NULL, &cmvn_stats);
192  writer.Write(utt, cmvn_stats);
193  num_done++;
194  }
195  KALDI_LOG << "Done accumulating CMVN stats for " << num_done
196  << " utterances; " << num_err << " had errors.";
197  return (num_done != 0 ? 0 : 1);
198  } catch(const std::exception &e) {
199  std::cerr << e.what();
200  return -1;
201  }
202 }
203 
204 
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void AccCmvnStatsForPair(const std::string &utt1, const std::string &utt2, const MatrixBase< BaseFloat > &feats1, const MatrixBase< BaseFloat > &feats2, BaseFloat quieter_channel_weight, MatrixBase< double > *cmvn_stats1, MatrixBase< double > *cmvn_stats2)
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
Base class which provides matrix operations not involving resizing or allocation. ...
Definition: kaldi-matrix.h:49
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
std::istream & Stream()
Definition: kaldi-io.cc:826
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
const T & Value(const std::string &key)
void SplitStringToVector(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< std::string > *out)
Split a string using any of the single character delimiters.
Definition: text-utils.cc:63
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
#define KALDI_WARN
Definition: kaldi-error.h:150
bool HasKey(const std::string &key)
void GetUtterancePairs(const std::string &reco2file_and_channel_rxfilename, std::vector< std::vector< std::string > > *utt_pairs)
void AccCmvnStats(const VectorBase< BaseFloat > &feats, BaseFloat weight, MatrixBase< double > *stats)
Accumulation from a single frame (weighted).
Definition: cmvn.cc:30
int NumArgs() const
Number of positional parameters (c.f. argc-1).
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
std::string PrintableRxfilename(const std::string &rxfilename)
PrintableRxfilename turns the rxfilename into a more human-readable form for error reporting...
Definition: kaldi-io.cc:61
#define KALDI_LOG
Definition: kaldi-error.h:153
int main(int argc, char *argv[])