29 int main(
int argc,
char *argv[]) {
30 using namespace kaldi;
33 "Accumulate stats for SGMM training, given Gaussian-level posteriors\n" 34 "Usage: sgmm2-acc-stats-gpost [options] <model-in> <feature-rspecifier> " 35 "<gpost-rspecifier> <stats-out>\n" 36 "e.g.: sgmm2-acc-stats-gpost 1.mdl 1.ali scp:train.scp ark, s, cs:- 1.acc\n";
40 std::string spkvecs_rspecifier, utt2spk_rspecifier;
41 std::string update_flags_str =
"vMNwcSt";
44 po.
Register(
"binary", &binary,
"Write output in binary mode");
45 po.
Register(
"spk-vecs", &spkvecs_rspecifier,
"Speaker vectors (rspecifier)");
46 po.
Register(
"utt2spk", &utt2spk_rspecifier,
47 "rspecifier for utterance to speaker map");
48 po.
Register(
"rand-prune", &rand_prune,
"Pruning threshold for posteriors");
49 po.
Register(
"update-flags", &update_flags_str,
"Which SGMM parameters to update: subset of vMNwcS.");
59 std::string model_filename = po.
GetArg(1),
60 feature_rspecifier = po.
GetArg(2),
61 gpost_rspecifier = po.
GetArg(3),
62 accs_wxfilename = po.
GetArg(4);
64 using namespace kaldi;
79 Input ki(model_filename, &binary);
92 int32 num_done = 0, num_err = 0;
96 for (; !feature_reader.
Done(); feature_reader.
Next()) {
97 std::string utt = feature_reader.
Key();
98 std::string spk = utt;
100 if (!utt2spk_rspecifier.empty()) {
101 if (!utt2spk_map.
HasKey(utt)) {
102 KALDI_WARN <<
"utt2spk map does not have value for " << utt
103 <<
", ignoring this utterance.";
105 }
else { spk = utt2spk_map.
Value(utt); }
108 if (spk != cur_spk && cur_spk !=
"")
111 if (spk != cur_spk || spk_vars.
Empty()) {
113 if (spkvecs_reader.
IsOpen()) {
114 if (spkvecs_reader.
HasKey(utt)) {
118 KALDI_WARN <<
"Cannot find speaker vector for " << utt;
128 if (!gpost_reader.
HasKey(utt) ||
130 KALDI_WARN <<
"No Gaussian-posterior information for utterance " 131 << utt <<
" (or wrong size).";
140 for (
size_t i = 0;
i < gpost.size();
i++) {
141 const std::vector<int32> &gselect = gpost[
i].gselect;
145 for (
size_t j = 0;
j < gpost[
i].tids.size();
j++) {
146 int32 tid = gpost[
i].tids[
j],
150 trans_model.
Accumulate(weight, tid, &transition_accs);
152 gpost[
i].posteriors[
j],
154 tot_weight += weight;
159 if (num_done % 50 == 0)
160 KALDI_LOG <<
"Processed " << num_done <<
" utterances";
164 KALDI_LOG <<
"Overall number of frames is " << tot_t;
165 KALDI_LOG <<
"Done " << num_done <<
" files, " 166 << num_err <<
" with errors.";
169 Output ko(accs_wxfilename, binary);
174 return (num_done != 0 ? 0 : 1);
175 }
catch(
const std::exception &e) {
176 std::cerr << e.what();
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
Class for definition of the subspace Gmm acoustic model.
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
This class is for when you are reading something in random access, but it may actually be stored per-...
void Write(std::ostream &Out, bool binary) const
Writes to C++ stream (option to write in binary).
void Read(std::istream &is, bool binary)
SgmmUpdateFlagsType StringToSgmmUpdateFlags(std::string str)
int32 TransitionIdToPdf(int32 trans_id) const
void ComputePerSpkDerivedVars(Sgmm2PerSpkDerivedVars *vars) const
Computes the per-speaker derived vars; assumes vars->v_s is already set up.
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
void CommitStatsForSpk(const AmSgmm2 &model, const Sgmm2PerSpkDerivedVars &spk_vars)
Accumulates global stats for the current speaker (if applicable).
void InitStats(Vector< double > *stats) const
uint16 SgmmUpdateFlagsType
Bitwise OR of the above flags.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
const T & Value(const std::string &key)
void Read(std::istream &is, bool binary)
void Accumulate(BaseFloat prob, int32 trans_id, Vector< double > *stats) const
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
void ComputePerFrameVars(const VectorBase< BaseFloat > &data, const std::vector< int32 > &gselect, const Sgmm2PerSpkDerivedVars &spk_vars, Sgmm2PerFrameDerivedVars *per_frame_vars) const
This needs to be called with each new frame of data, prior to accumulation or likelihood evaluation: ...
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
bool HasKey(const std::string &key)
BaseFloat AccumulateFromPosteriors(const AmSgmm2 &model, const Sgmm2PerFrameDerivedVars &frame_vars, const Matrix< BaseFloat > &posteriors, int32 pdf_index, Sgmm2PerSpkDerivedVars *spk_vars)
Returns count accumulated (may differ from posteriors.Sum() due to weight pruning).
void Write(std::ostream &out_stream, bool binary) const
int NumArgs() const
Number of positional parameters (c.f. argc-1).
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
int main(int argc, char *argv[])
const T & Value(const std::string &key)
void SetSpeakerVector(const Vector< BaseFloat > &v_s_in)
Class for the accumulators associated with the phonetic-subspace model parameters.
Holds the per-frame precomputed quantities x(t), x_{i}(t), z_{i}(t), and n_{i}(t) (cf...
void ResizeAccumulators(const AmSgmm2 &model, SgmmUpdateFlagsType flags, bool have_spk_vecs)
Resizes the accumulators to the correct sizes given the model.