29 int main(
int argc,
char *argv[]) {
30 using namespace kaldi;
33 "Accumulate numerator and denominator stats for discriminative training\n" 34 "of SGMMs (input is posteriors of mixed sign)\n" 35 "Usage: sgmm2-acc-stats2 [options] <model-in> <feature-rspecifier> " 36 "<posteriors-rspecifier> <num-stats-out> <den-stats-out>\n" 37 "e.g.: sgmm2-acc-stats2 1.mdl 1.ali scp:train.scp ark:1.posts num.acc den.acc\n";
41 std::string gselect_rspecifier, spkvecs_rspecifier, utt2spk_rspecifier;
42 std::string update_flags_str =
"vMNwucSt";
45 po.
Register(
"binary", &binary,
"Write output in binary mode");
46 po.
Register(
"gselect", &gselect_rspecifier,
"Precomputed Gaussian indices (rspecifier)");
47 po.
Register(
"spk-vecs", &spkvecs_rspecifier,
"Speaker vectors (rspecifier)");
48 po.
Register(
"utt2spk", &utt2spk_rspecifier,
49 "rspecifier for utterance to speaker map");
50 po.
Register(
"rand-prune", &rand_prune,
"Pruning threshold for posteriors");
51 po.
Register(
"update-flags", &update_flags_str,
"Which SGMM parameters to accumulate " 52 "stats for: subset of vMNwcS.");
63 std::string model_filename = po.
GetArg(1),
64 feature_rspecifier = po.
GetArg(2),
65 posteriors_rspecifier = po.
GetArg(3),
66 num_accs_wxfilename = po.
GetArg(4),
67 den_accs_wxfilename = po.
GetArg(5);
70 using namespace kaldi;
72 typedef kaldi::int64 int64;
88 Input ki(model_filename, &binary);
95 KALDI_WARN <<
"Removing speaker weight projections (u) from flags " 96 "as not present in model\n";
100 KALDI_WARN <<
"Removing speaker projections (N) from flags " 101 "as not present in model\n";
106 trans_model.
InitStats(&num_transition_accs);
107 trans_model.
InitStats(&den_transition_accs);
109 MleAmSgmm2Accs num_sgmm_accs(rand_prune), den_sgmm_accs(rand_prune);
110 bool have_spk_vecs = (spkvecs_rspecifier !=
"");
111 num_sgmm_accs.ResizeAccumulators(am_sgmm, acc_flags, have_spk_vecs);
114 double tot_like = 0.0, tot_weight = 0.0, tot_abs_weight = 0.0;
115 int64 tot_frames = 0;
119 int32 num_done = 0, num_err = 0;
123 for (; !feature_reader.
Done(); feature_reader.
Next()) {
124 std::string utt = feature_reader.
Key();
125 std::string spk = utt;
126 if (!utt2spk_rspecifier.empty()) {
127 if (!utt2spk_map.
HasKey(utt)) {
128 KALDI_WARN <<
"utt2spk map does not have value for " << utt
129 <<
", ignoring this utterance.";
131 }
else { spk = utt2spk_map.
Value(utt); }
133 if (spk != cur_spk && cur_spk !=
"") {
134 num_sgmm_accs.CommitStatsForSpk(am_sgmm, spk_vars);
137 if (spk != cur_spk || spk_vars.
Empty()) {
139 if (spkvecs_reader.
IsOpen()) {
140 if (spkvecs_reader.
HasKey(utt)) {
144 KALDI_WARN <<
"Cannot find speaker vector for " << utt;
153 if (!posteriors_reader.
HasKey(utt) ||
154 posteriors_reader.
Value(utt).size() != features.
NumRows()) {
155 KALDI_WARN <<
"No posterior info available for utterance " 156 << utt <<
" (or wrong size)";
162 if (!gselect_reader.
HasKey(utt)
163 && gselect_reader.
Value(utt).size() != features.
NumRows()) {
164 KALDI_WARN <<
"No Gaussian-selection info available for utterance " 165 << utt <<
" (or wrong size)";
168 const std::vector<std::vector<int32> > &gselect =
169 gselect_reader.
Value(utt);
172 BaseFloat tot_like_this_file = 0.0, tot_weight_this_file = 0.0,
173 tot_abs_weight_this_file = 0.0;
175 for (
size_t i = 0;
i < posterior.size();
i++) {
176 if (posterior[
i].empty())
181 for (
size_t j = 0;
j < posterior[
i].size();
j++) {
182 int32 tid = posterior[
i][
j].first,
185 abs_weight = std::abs(weight);
187 if (acc_flags & kaldi::kSgmmTransitions) {
188 trans_model.
Accumulate(abs_weight, tid, weight > 0 ?
189 &num_transition_accs : &den_transition_accs);
191 tot_like_this_file +=
192 (weight > 0 ? num_sgmm_accs : den_sgmm_accs).Accumulate(
193 am_sgmm, per_frame_vars, pdf_id, abs_weight, &spk_vars)
195 tot_weight_this_file += weight;
196 tot_abs_weight_this_file += abs_weight;
200 num_sgmm_accs.CommitStatsForSpk(am_sgmm, spk_vars);
204 tot_like += tot_like_this_file;
205 tot_weight += tot_weight_this_file;
206 tot_abs_weight += tot_abs_weight_this_file;
207 tot_frames += posterior.size();
208 if (num_done % 50 == 0)
209 KALDI_LOG <<
"Processed " << num_done <<
" utterances.";
212 num_sgmm_accs.CommitStatsForSpk(am_sgmm, spk_vars);
215 KALDI_LOG <<
"Overall weighted acoustic likelihood per frame was " 216 << (tot_like/tot_frames) <<
" over " << tot_frames <<
" frames; " 217 <<
"average weight per frame is " << (tot_weight/tot_frames)
218 <<
", average abs(weight) per frame is " 219 << (tot_abs_weight/tot_frames);
221 KALDI_LOG <<
"Done " << num_done <<
" files, " << num_err
225 Output ko(num_accs_wxfilename, binary);
227 num_sgmm_accs.Write(ko.
Stream(), binary);
230 Output ko(den_accs_wxfilename, binary);
235 return (num_done != 0 ? 0 : 1);
236 }
catch(
const std::exception &e) {
237 std::cerr << e.what();
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
Class for definition of the subspace Gmm acoustic model.
bool HasSpeakerSpace() const
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
This class is for when you are reading something in random access, but it may actually be stored per-...
void Write(std::ostream &Out, bool binary) const
Writes to C++ stream (option to write in binary).
void Read(std::istream &is, bool binary)
SgmmUpdateFlagsType StringToSgmmUpdateFlags(std::string str)
int32 TransitionIdToPdf(int32 trans_id) const
void ComputePerSpkDerivedVars(Sgmm2PerSpkDerivedVars *vars) const
Computes the per-speaker derived vars; assumes vars->v_s is already set up.
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
void CommitStatsForSpk(const AmSgmm2 &model, const Sgmm2PerSpkDerivedVars &spk_vars)
Accumulates global stats for the current speaker (if applicable).
t: transition-model statistics (not really part of the SGMM itself).
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
void InitStats(Vector< double > *stats) const
uint16 SgmmUpdateFlagsType
Bitwise OR of the above flags.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
const T & Value(const std::string &key)
void Read(std::istream &is, bool binary)
void Accumulate(BaseFloat prob, int32 trans_id, Vector< double > *stats) const
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
void ComputePerFrameVars(const VectorBase< BaseFloat > &data, const std::vector< int32 > &gselect, const Sgmm2PerSpkDerivedVars &spk_vars, Sgmm2PerFrameDerivedVars *per_frame_vars) const
This needs to be called with each new frame of data, prior to accumulation or likelihood evaluation: ...
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
bool HasKey(const std::string &key)
void Write(std::ostream &out_stream, bool binary) const
int NumArgs() const
Number of positional parameters (c.f. argc-1).
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
const T & Value(const std::string &key)
void SetSpeakerVector(const Vector< BaseFloat > &v_s_in)
bool HasSpeakerDependentWeights() const
True if doing SSGMM.
Class for the accumulators associated with the phonetic-subspace model parameters.
Holds the per-frame precomputed quantities x(t), x_{i}(t), z_{i}(t), and n_{i}(t) (cf...
int main(int argc, char *argv[])
void ResizeAccumulators(const AmSgmm2 &model, SgmmUpdateFlagsType flags, bool have_spk_vecs)
Resizes the accumulators to the correct sizes given the model.