30 using namespace kaldi;
33 "Accumulate numerator and denominator stats for discriminative training\n" 34 "of SGMMs (input is posteriors of mixed sign)\n" 35 "Usage: sgmm2-acc-stats2 [options] <model-in> <feature-rspecifier> " 36 "<posteriors-rspecifier> <num-stats-out> <den-stats-out>\n" 37 "e.g.: sgmm2-acc-stats2 1.mdl 1.ali scp:train.scp ark:1.posts num.acc den.acc\n";
41 std::string gselect_rspecifier, spkvecs_rspecifier, utt2spk_rspecifier;
42 std::string update_flags_str =
"vMNwucSt";
45 po.Register(
"binary", &binary,
"Write output in binary mode");
46 po.Register(
"gselect", &gselect_rspecifier,
"Precomputed Gaussian indices (rspecifier)");
47 po.Register(
"spk-vecs", &spkvecs_rspecifier,
"Speaker vectors (rspecifier)");
48 po.Register(
"utt2spk", &utt2spk_rspecifier,
49 "rspecifier for utterance to speaker map");
50 po.Register(
"rand-prune", &rand_prune,
"Pruning threshold for posteriors");
51 po.Register(
"update-flags", &update_flags_str,
"Which SGMM parameters to accumulate " 52 "stats for: subset of vMNwcS.");
58 if (po.NumArgs() != 5) {
63 std::string model_filename = po.GetArg(1),
64 feature_rspecifier = po.GetArg(2),
65 posteriors_rspecifier = po.GetArg(3),
66 num_accs_wxfilename = po.GetArg(4),
67 den_accs_wxfilename = po.GetArg(5);
70 using namespace kaldi;
72 typedef kaldi::int64 int64;
88 Input ki(model_filename, &binary);
89 trans_model.
Read(ki.Stream(), binary);
90 am_sgmm.
Read(ki.Stream(), binary);
95 KALDI_WARN <<
"Removing speaker weight projections (u) from flags " 96 "as not present in model\n";
100 KALDI_WARN <<
"Removing speaker projections (N) from flags " 101 "as not present in model\n";
106 trans_model.
InitStats(&num_transition_accs);
107 trans_model.
InitStats(&den_transition_accs);
109 MleAmSgmm2Accs num_sgmm_accs(rand_prune), den_sgmm_accs(rand_prune);
110 bool have_spk_vecs = (spkvecs_rspecifier !=
"");
111 num_sgmm_accs.ResizeAccumulators(am_sgmm, acc_flags, have_spk_vecs);
112 den_sgmm_accs.ResizeAccumulators(am_sgmm, acc_flags, have_spk_vecs);
114 double tot_like = 0.0, tot_weight = 0.0, tot_abs_weight = 0.0;
115 int64 tot_frames = 0;
119 int32 num_done = 0, num_err = 0;
123 for (; !feature_reader.Done(); feature_reader.Next()) {
124 std::string utt = feature_reader.Key();
125 std::string spk = utt;
126 if (!utt2spk_rspecifier.empty()) {
127 if (!utt2spk_map.HasKey(utt)) {
128 KALDI_WARN <<
"utt2spk map does not have value for " << utt
129 <<
", ignoring this utterance.";
131 }
else { spk = utt2spk_map.Value(utt); }
133 if (spk != cur_spk && cur_spk !=
"") {
134 num_sgmm_accs.CommitStatsForSpk(am_sgmm, spk_vars);
135 den_sgmm_accs.CommitStatsForSpk(am_sgmm, spk_vars);
137 if (spk != cur_spk || spk_vars.
Empty()) {
139 if (spkvecs_reader.IsOpen()) {
140 if (spkvecs_reader.HasKey(utt)) {
144 KALDI_WARN <<
"Cannot find speaker vector for " << utt;
153 if (!posteriors_reader.HasKey(utt) ||
154 posteriors_reader.Value(utt).size() != features.
NumRows()) {
155 KALDI_WARN <<
"No posterior info available for utterance " 156 << utt <<
" (or wrong size)";
161 const Posterior &posterior = posteriors_reader.Value(utt);
162 if (!gselect_reader.HasKey(utt)
163 && gselect_reader.Value(utt).size() != features.
NumRows()) {
164 KALDI_WARN <<
"No Gaussian-selection info available for utterance " 165 << utt <<
" (or wrong size)";
168 const std::vector<std::vector<int32> > &gselect =
169 gselect_reader.Value(utt);
172 BaseFloat tot_like_this_file = 0.0, tot_weight_this_file = 0.0,
173 tot_abs_weight_this_file = 0.0;
175 for (
size_t i = 0;
i < posterior.size();
i++) {
176 if (posterior[
i].empty())
181 for (
size_t j = 0;
j < posterior[
i].size();
j++) {
182 int32 tid = posterior[
i][
j].first,
185 abs_weight = std::abs(weight);
187 if (acc_flags & kaldi::kSgmmTransitions) {
188 trans_model.
Accumulate(abs_weight, tid, weight > 0 ?
189 &num_transition_accs : &den_transition_accs);
191 tot_like_this_file +=
192 (weight > 0 ? num_sgmm_accs : den_sgmm_accs).Accumulate(
193 am_sgmm, per_frame_vars, pdf_id, abs_weight, &spk_vars)
195 tot_weight_this_file += weight;
196 tot_abs_weight_this_file += abs_weight;
200 num_sgmm_accs.CommitStatsForSpk(am_sgmm, spk_vars);
201 den_sgmm_accs.CommitStatsForSpk(am_sgmm, spk_vars);
204 tot_like += tot_like_this_file;
205 tot_weight += tot_weight_this_file;
206 tot_abs_weight += tot_abs_weight_this_file;
207 tot_frames += posterior.size();
208 if (num_done % 50 == 0)
209 KALDI_LOG <<
"Processed " << num_done <<
" utterances.";
212 num_sgmm_accs.CommitStatsForSpk(am_sgmm, spk_vars);
213 den_sgmm_accs.CommitStatsForSpk(am_sgmm, spk_vars);
215 KALDI_LOG <<
"Overall weighted acoustic likelihood per frame was " 216 << (tot_like/tot_frames) <<
" over " << tot_frames <<
" frames; " 217 <<
"average weight per frame is " << (tot_weight/tot_frames)
218 <<
", average abs(weight) per frame is " 219 << (tot_abs_weight/tot_frames);
221 KALDI_LOG <<
"Done " << num_done <<
" files, " << num_err
225 Output ko(num_accs_wxfilename, binary);
226 num_transition_accs.
Write(ko.Stream(), binary);
227 num_sgmm_accs.Write(ko.Stream(), binary);
230 Output ko(den_accs_wxfilename, binary);
231 den_transition_accs.
Write(ko.Stream(), binary);
232 den_sgmm_accs.Write(ko.Stream(), binary);
235 return (num_done != 0 ? 0 : 1);
236 }
catch(
const std::exception &e) {
237 std::cerr << e.what();
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Class for definition of the subspace GMM acoustic model.
bool HasSpeakerSpace() const
This class is for when you are reading something in random access, but it may actually be stored per-...
void Write(std::ostream &Out, bool binary) const
Writes to C++ stream (option to write in binary).
void Read(std::istream &is, bool binary)
SgmmUpdateFlagsType StringToSgmmUpdateFlags(std::string str)
int32 TransitionIdToPdf(int32 trans_id) const
void ComputePerSpkDerivedVars(Sgmm2PerSpkDerivedVars *vars) const
Computes the per-speaker derived vars; assumes vars->v_s is already set up.
Allows random access to a collection of objects in an archive or script file; see The Table concept...
t: transition-model statistics (not really part of the SGMM itself).
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an utterance.
void InitStats(Vector< double > *stats) const
uint16 SgmmUpdateFlagsType
Bitwise OR of the SGMM update flags (e.g. kSgmmTransitions), as produced by StringToSgmmUpdateFlags().
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
void Read(std::istream &is, bool binary)
void Accumulate(BaseFloat prob, int32 trans_id, Vector< double > *stats) const
A templated class for reading objects sequentially from an archive or script file; see The Table concept.
void ComputePerFrameVars(const VectorBase< BaseFloat > &data, const std::vector< int32 > &gselect, const Sgmm2PerSpkDerivedVars &spk_vars, Sgmm2PerFrameDerivedVars *per_frame_vars) const
This needs to be called with each new frame of data, prior to accumulation or likelihood evaluation: ...
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
void SetSpeakerVector(const Vector< BaseFloat > &v_s_in)
bool HasSpeakerDependentWeights() const
True if doing SSGMM, i.e. the model has speaker-dependent mixture weights.
Class for the accumulators associated with the phonetic-subspace model parameters.
Holds the per-frame precomputed quantities x(t), x_{i}(t), z_{i}(t), and n_{i}(t) (cf...