30 using namespace kaldi;
33 "Accumulate stats for SGMM training, given Gaussian-level posteriors\n" 34 "Usage: sgmm2-acc-stats-gpost [options] <model-in> <feature-rspecifier> " 35 "<gpost-rspecifier> <stats-out>\n" 36 "e.g.: sgmm2-acc-stats-gpost 1.mdl 1.ali scp:train.scp ark, s, cs:- 1.acc\n";
40 std::string spkvecs_rspecifier, utt2spk_rspecifier;
41 std::string update_flags_str =
"vMNwcSt";
44 po.Register(
"binary", &binary,
"Write output in binary mode");
45 po.Register(
"spk-vecs", &spkvecs_rspecifier,
"Speaker vectors (rspecifier)");
46 po.Register(
"utt2spk", &utt2spk_rspecifier,
47 "rspecifier for utterance to speaker map");
48 po.Register(
"rand-prune", &rand_prune,
"Pruning threshold for posteriors");
49 po.Register(
"update-flags", &update_flags_str,
"Which SGMM parameters to update: subset of vMNwcS.");
54 if (po.NumArgs() != 4) {
59 std::string model_filename = po.GetArg(1),
60 feature_rspecifier = po.GetArg(2),
61 gpost_rspecifier = po.GetArg(3),
62 accs_wxfilename = po.GetArg(4);
64 using namespace kaldi;
79 Input ki(model_filename, &binary);
80 trans_model.
Read(ki.Stream(), binary);
81 am_sgmm.
Read(ki.Stream(), binary);
87 sgmm_accs.ResizeAccumulators(am_sgmm, acc_flags, (spkvecs_rspecifier !=
""));
92 int32 num_done = 0, num_err = 0;
96 for (; !feature_reader.Done(); feature_reader.Next()) {
97 std::string utt = feature_reader.Key();
98 std::string spk = utt;
100 if (!utt2spk_rspecifier.empty()) {
101 if (!utt2spk_map.HasKey(utt)) {
102 KALDI_WARN <<
"utt2spk map does not have value for " << utt
103 <<
", ignoring this utterance.";
105 }
else { spk = utt2spk_map.Value(utt); }
108 if (spk != cur_spk && cur_spk !=
"")
109 sgmm_accs.CommitStatsForSpk(am_sgmm, spk_vars);
111 if (spk != cur_spk || spk_vars.
Empty()) {
113 if (spkvecs_reader.IsOpen()) {
114 if (spkvecs_reader.HasKey(utt)) {
118 KALDI_WARN <<
"Cannot find speaker vector for " << utt;
128 if (!gpost_reader.HasKey(utt) ||
129 gpost_reader.Value(utt).size() != mat.
NumRows()) {
130 KALDI_WARN <<
"No Gaussian-posterior information for utterance " 131 << utt <<
" (or wrong size).";
140 for (
size_t i = 0;
i < gpost.size();
i++) {
141 const std::vector<int32> &gselect = gpost[
i].gselect;
145 for (
size_t j = 0;
j < gpost[
i].tids.size();
j++) {
146 int32 tid = gpost[
i].tids[
j],
150 trans_model.
Accumulate(weight, tid, &transition_accs);
151 sgmm_accs.AccumulateFromPosteriors(am_sgmm, per_frame_vars,
152 gpost[
i].posteriors[
j],
154 tot_weight += weight;
159 if (num_done % 50 == 0)
160 KALDI_LOG <<
"Processed " << num_done <<
" utterances";
162 sgmm_accs.CommitStatsForSpk(am_sgmm, spk_vars);
164 KALDI_LOG <<
"Overall number of frames is " << tot_t;
165 KALDI_LOG <<
"Done " << num_done <<
" files, " 166 << num_err <<
" with errors.";
169 Output ko(accs_wxfilename, binary);
170 transition_accs.
Write(ko.Stream(), binary);
171 sgmm_accs.Write(ko.Stream(), binary);
174 return (num_done != 0 ? 0 : 1);
175 }
catch(
const std::exception &e) {
176 std::cerr << e.what();
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
Class for definition of the subspace Gmm acoustic model.
This class is for when you are reading something in random access, but it may actually be stored per-...
void Write(std::ostream &Out, bool binary) const
Writes to C++ stream (option to write in binary).
void Read(std::istream &is, bool binary)
SgmmUpdateFlagsType StringToSgmmUpdateFlags(std::string str)
int32 TransitionIdToPdf(int32 trans_id) const
void ComputePerSpkDerivedVars(Sgmm2PerSpkDerivedVars *vars) const
Computes the per-speaker derived vars; assumes vars->v_s is already set up.
Allows random access to a collection of objects in an archive or script file; see The Table concept...
void InitStats(Vector< double > *stats) const
uint16 SgmmUpdateFlagsType
Bitwise OR of the above flags.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
void Read(std::istream &is, bool binary)
void Accumulate(BaseFloat prob, int32 trans_id, Vector< double > *stats) const
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
void ComputePerFrameVars(const VectorBase< BaseFloat > &data, const std::vector< int32 > &gselect, const Sgmm2PerSpkDerivedVars &spk_vars, Sgmm2PerFrameDerivedVars *per_frame_vars) const
This needs to be called with each new frame of data, prior to accumulation or likelihood evaluation: ...
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
void SetSpeakerVector(const Vector< BaseFloat > &v_s_in)
Class for the accumulators associated with the phonetic-subspace model parameters.
Holds the per-frame precomputed quantities x(t), x_{i}(t), z_{i}(t), and n_{i}(t) (cf...