mle-am-diag-gmm.h
Go to the documentation of this file.
1 // gmm/mle-am-diag-gmm.h
2 
3 // Copyright 2009-2012 Saarland University (author: Arnab Ghoshal);
4 // Yanmin Qian; Johns Hopkins University (author: Daniel Povey)
5 // Cisco Systems (author: Neha Agrawal)
6 
7 // See ../../COPYING for clarification regarding multiple authors
8 //
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 //
13 // http://www.apache.org/licenses/LICENSE-2.0
14 //
15 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
17 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
18 // MERCHANTABLITY OR NON-INFRINGEMENT.
19 // See the Apache 2 License for the specific language governing permissions and
20 // limitations under the License.
21 
22 
23 #ifndef KALDI_GMM_MLE_AM_DIAG_GMM_H_
24 #define KALDI_GMM_MLE_AM_DIAG_GMM_H_ 1
25 
26 #include <vector>
27 
28 #include "gmm/am-diag-gmm.h"
29 #include "gmm/mle-diag-gmm.h"
30 #include "util/common-utils.h"
31 
32 namespace kaldi {
33 
35  public:
38 
/// Reads accumulator data from in_stream.
/// NOTE(review): `binary` presumably selects binary vs. text format, and
/// `add` (default false) presumably adds the stats that are read to the
/// existing ones instead of replacing them -- confirm in the implementation.
39  void Read(std::istream &in_stream, bool binary, bool add = false);
/// Writes the accumulators to out_stream without modifying this object
/// (const member). NOTE(review): `binary` presumably selects binary vs. text
/// format -- confirm in the implementation.
40  void Write(std::ostream &out_stream, bool binary) const;
41 
/// Initializes accumulators for each GMM in `model`, based on the number of
/// components and dimension. `flags` is a GmmFlagsType bitmask (bitwise OR of
/// the GMM update flags).
44  void Init(const AmDiagGmm &model, GmmFlagsType flags);
/// Overload of Init() that additionally takes an explicit `dim`.
/// NOTE(review): presumably `dim` is used as the accumulator dimension
/// instead of the model's own dimension -- confirm in the implementation.
46  void Init(const AmDiagGmm &model, int32 dim, GmmFlagsType flags);
/// NOTE(review): presumably zeroes the statistics selected by `flags` in
/// every accumulator -- confirm in the implementation.
47  void SetZero(GmmFlagsType flags);
48 
52  const VectorBase<BaseFloat> &data,
53  int32 gmm_index, BaseFloat weight);
54 
58  const VectorBase<BaseFloat> &data1,
59  const VectorBase<BaseFloat> &data2,
60  int32 gmm_index, BaseFloat weight);
61 
/// Accumulates stats for a single GMM in the model (selected by `gmm_index`)
/// using pre-computed Gaussian posteriors.
/// NOTE(review): `posteriors` presumably holds one entry per Gaussian of the
/// selected GMM -- confirm in the implementation.
64  void AccumulateFromPosteriors(const AmDiagGmm &model,
65  const VectorBase<BaseFloat> &data,
66  int32 gmm_index,
67  const VectorBase<BaseFloat> &posteriors);
68 
/// Accumulate stats for a single Gaussian component in the model, selected by
/// `gmm_index` and `gauss_index`. Note that total_frames_ is not updated by
/// this function, so TotCount() does not reflect stats added through it.
70  void AccumulateForGaussian(const AmDiagGmm &am,
71  const VectorBase<BaseFloat> &data,
72  int32 gmm_index, int32 gauss_index,
73  BaseFloat weight);
74 
/// Returns the number of entries in gmm_accumulators_ (converted from the
/// vector's size type to int32).
/// NOTE(review): this non-const overload has the same body as the const
/// overload declared alongside it and looks redundant.
75  int32 NumAccs() { return gmm_accumulators_.size(); }
76 
/// Returns the number of entries in gmm_accumulators_ (converted from the
/// vector's size type to int32); callable on const objects.
77  int32 NumAccs() const { return gmm_accumulators_.size(); }
78 
79  BaseFloat TotStatsCount() const; // Returns the total count obtained by summing the counts
80  // of the actual stats; this may differ from TotCount(), e.g. if you did I-smoothing.
81 
82  // Returns total_frames_. Be careful: total_frames_ is not updated in AccumulateForGaussian.
83  BaseFloat TotCount() const { return total_frames_; }
/// Returns total_log_like_, the total likelihood kept for diagnostics.
84  BaseFloat TotLogLike() const { return total_log_like_; }
85 
/// Read-only access to the accumulator for the GMM at `index`.
/// NOTE(review): any range checking on `index` happens in the implementation
/// and is not visible here.
86  const AccumDiagGmm& GetAcc(int32 index) const;
87 
/// Mutable access to the accumulator for the GMM at `index`.
/// NOTE(review): any range checking on `index` happens in the implementation
/// and is not visible here.
88  AccumDiagGmm& GetAcc(int32 index);
89 
/// NOTE(review): presumably adds `scale` times the statistics held in `other`
/// to this object's statistics -- confirm in the implementation.
90  void Add(BaseFloat scale, const AccumAmDiagGmm &other);
91 
/// NOTE(review): presumably multiplies all accumulated statistics by `scale`
/// -- confirm in the implementation.
92  void Scale(BaseFloat scale);
93 
/// Returns the Dim() of the first accumulator, or 0 if gmm_accumulators_ is
/// empty or its first entry is a null pointer. Only the first accumulator is
/// consulted.
94  int32 Dim() const {
95  return (gmm_accumulators_.empty() || !gmm_accumulators_[0] ?
96  0 : gmm_accumulators_[0]->Dim());
97  }
98 
99  private:
101  std::vector<AccumDiagGmm*> gmm_accumulators_;
102 
105 
106  // Cannot have copy constructor and assignment operator
108 };
109 
/// Computes maximum-likelihood estimates of the parameters of an AmDiagGmm
/// acoustic model from accumulated statistics.
///   config         - MLE update options.
///   amdiaggmm_acc  - accumulated statistics (read-only).
///   flags          - GmmFlagsType bitmask (bitwise OR of the GMM update flags).
///   am_gmm         - the model, passed by non-const pointer; presumably
///                    updated in place (TODO confirm).
///   obj_change_out, count_out - presumably outputs for the objective-function
///                    change and the count; whether NULL is accepted is not
///                    visible here (TODO confirm).
112 void MleAmDiagGmmUpdate(const MleDiagGmmOptions &config,
113  const AccumAmDiagGmm &amdiaggmm_acc,
114  GmmFlagsType flags,
115  AmDiagGmm *am_gmm,
116  BaseFloat *obj_change_out,
117  BaseFloat *count_out);
118 
/// Maximum A Posteriori update of an AmDiagGmm acoustic model from
/// accumulated statistics.
///   config        - MAP update options.
///   diag_gmm_acc  - accumulated statistics (read-only).
///   flags         - GmmFlagsType bitmask (bitwise OR of the GMM update flags).
///   gmm           - the model, passed by non-const pointer; presumably
///                   updated in place (TODO confirm).
///   obj_change_out, count_out - presumably outputs for the objective-function
///                   change and the count; whether NULL is accepted is not
///                   visible here (TODO confirm).
120 void MapAmDiagGmmUpdate(const MapDiagGmmOptions &config,
121  const AccumAmDiagGmm &diag_gmm_acc,
122  GmmFlagsType flags,
123  AmDiagGmm *gmm,
124  BaseFloat *obj_change_out,
125  BaseFloat *count_out);
126 
127 // These typedefs are needed to read GMMs from and write GMMs to pipes, for
128 // MAP adaptation and decoding. Note: this doesn't handle the transition
129 // model; you have to read that in separately.
134 
135 } // End namespace kaldi
136 
137 
138 #endif // KALDI_GMM_MLE_AM_DIAG_GMM_H_
void MleAmDiagGmmUpdate(const MleDiagGmmOptions &config, const AccumAmDiagGmm &am_diag_gmm_acc, GmmFlagsType flags, AmDiagGmm *am_gmm, BaseFloat *obj_change_out, BaseFloat *count_out)
for computing the maximum-likelihood estimates of the parameters of an acoustic model that uses diago...
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void MapAmDiagGmmUpdate(const MapDiagGmmOptions &config, const AccumAmDiagGmm &am_diag_gmm_acc, GmmFlagsType flags, AmDiagGmm *am_gmm, BaseFloat *obj_change_out, BaseFloat *count_out)
Maximum A Posteriori update.
This class is for when you are reading something in random access, but it may actually be stored per-...
Definition: kaldi-table.h:432
void SetZero(GmmFlagsType flags)
BaseFloat AccumulateForGmmTwofeats(const AmDiagGmm &model, const VectorBase< BaseFloat > &data1, const VectorBase< BaseFloat > &data2, int32 gmm_index, BaseFloat weight)
Accumulate stats for a single GMM in the model; uses data1 for getting posteriors and data2 for stats...
RandomAccessTableReaderMapped< KaldiObjectHolder< AmDiagGmm > > RandomAccessMapAmDiagGmmReaderMapped
BaseFloat AccumulateForGmm(const AmDiagGmm &model, const VectorBase< BaseFloat > &data, int32 gmm_index, BaseFloat weight)
Accumulate stats for a single GMM in the model; returns log likelihood.
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
SequentialTableReader< KaldiObjectHolder< AmDiagGmm > > MapAmDiagGmmSeqReader
BaseFloat TotCount() const
uint16 GmmFlagsType
Bitwise OR of the above flags.
Definition: model-common.h:35
void Add(BaseFloat scale, const AccumAmDiagGmm &other)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
void AccumulateFromPosteriors(const AmDiagGmm &model, const VectorBase< BaseFloat > &data, int32 gmm_index, const VectorBase< BaseFloat > &posteriors)
Accumulates stats for a single GMM in the model using pre-computed Gaussian posteriors.
double total_frames_
Total counts & likelihood (for diagnostics)
BaseFloat TotLogLike() const
void AccumulateForGaussian(const AmDiagGmm &am, const VectorBase< BaseFloat > &data, int32 gmm_index, int32 gauss_index, BaseFloat weight)
Accumulate stats for a single Gaussian component in the model.
TableWriter< KaldiObjectHolder< AmDiagGmm > > MapAmDiagGmmWriter
void Scale(BaseFloat scale)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
BaseFloat TotStatsCount() const
Configuration variables like variance floor, minimum occupancy, etc.
Definition: mle-diag-gmm.h:38
void Read(std::istream &in_stream, bool binary, bool add=false)
KALDI_DISALLOW_COPY_AND_ASSIGN(AccumAmDiagGmm)
RandomAccessTableReader< KaldiObjectHolder< AmDiagGmm > > RandomAccessMapAmDiagGmmReader
const AccumDiagGmm & GetAcc(int32 index) const
void Write(std::ostream &out_stream, bool binary) const
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
void Init(const AmDiagGmm &model, GmmFlagsType flags)
Initializes accumulators for each GMM based on the number of components and dimension.
std::vector< AccumDiagGmm * > gmm_accumulators_
MLE accumulators and update methods for the GMMs.
Configuration variables for Maximum A Posteriori (MAP) update.
Definition: mle-diag-gmm.h:76