PushSpecialClass Class Reference
Collaboration diagram for PushSpecialClass:

Public Member Functions

 PushSpecialClass (VectorFst< StdArc > *fst, float delta)
 

Private Types

typedef StdArc Arc
 
typedef Arc::Weight Weight
 
typedef Arc::StateId StateId
 

Private Member Functions

double TestAccuracy ()
 
void Iterate (float delta)
 
void ModifyFst ()
 

Private Attributes

StateId num_states_
 
StateId initial_state_
 
std::vector< double > occ_
 
double lambda_
 
std::vector< std::vector< std::pair< StateId, double > > > pred_
 
VectorFst< StdArc > * fst_
 

Detailed Description

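Reweights a VectorFst<StdArc> in place, treating its weights as negative log-probabilities.

- **Construction:** the constructor records, for each state, its (predecessor state, probability) pairs in pred_. A final-probability is treated as a transition from that state back to the initial state.
- **Potentials:** Iterate() runs a power iteration, shifted by 0.1 times the identity, to estimate per-state potentials (occ_).
- **Reweighting:** ModifyFst() then adjusts every arc weight and final weight by the difference of the potentials of its endpoints.
- **Stopping rule:** TestAccuracy() returns the log of the ratio between the largest and smallest per-state total of pushed outgoing probability (arcs plus final). Iteration stops once this value is at most delta. The check runs every 5 iterations, for at most 200 iterations.

Definition at line 86 of file push-special.cc.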

Member Typedef Documentation

◆ Arc

typedef StdArc Arc
private

Definition at line 87 of file push-special.cc.

◆ StateId

typedef Arc::StateId StateId
private

Definition at line 89 of file push-special.cc.

◆ Weight

typedef Arc::Weight Weight
private

Definition at line 88 of file push-special.cc.

Constructor & Destructor Documentation

◆ PushSpecialClass()

PushSpecialClass ( VectorFst< StdArc > *  fst,
float  delta 
)
inline

Definition at line 93 of file push-special.cc.

References kaldi::Exp(), PushSpecialClass::fst_, PushSpecialClass::initial_state_, PushSpecialClass::Iterate(), PushSpecialClass::ModifyFst(), PushSpecialClass::num_states_, PushSpecialClass::occ_, and PushSpecialClass::pred_.

94  : fst_(fst) {
95  num_states_ = fst_->NumStates();
96  initial_state_ = fst_->Start();
97  occ_.resize(num_states_, 1.0 / sqrt(num_states_)); // unit length
98 
99  pred_.resize(num_states_);
100  for (StateId s = 0; s < num_states_; s++) {
101  for (ArcIterator<VectorFst<StdArc> > aiter(*fst, s);
102  !aiter.Done(); aiter.Next()) {
103  const Arc &arc = aiter.Value();
104  StateId t = arc.nextstate;
105  double weight = kaldi::Exp(-arc.weight.Value());
106  pred_[t].push_back(std::make_pair(s, weight));
107  }
108  double final = kaldi::Exp(-fst_->Final(s).Value());
109  if (final != 0.0)
110  pred_[initial_state_].push_back(std::make_pair(s, final));
111  }
112  Iterate(delta);
113  ModifyFst();
114  }
double Exp(double x)
Definition: kaldi-math.h:83
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
Definition: graph.dox:21
VectorFst< StdArc > * fst_
std::vector< double > occ_
void Iterate(float delta)
std::vector< std::vector< std::pair< StateId, double > > > pred_
Arc::StateId StateId
Definition: push-special.cc:89

Member Function Documentation

◆ Iterate()

void Iterate ( float  delta)
inline private

Definition at line 143 of file push-special.cc.

References rnnlm::i, rnnlm::j, KALDI_VLOG, KALDI_WARN, PushSpecialClass::lambda_, PushSpecialClass::num_states_, PushSpecialClass::occ_, PushSpecialClass::pred_, and PushSpecialClass::TestAccuracy().

Referenced by PushSpecialClass::PushSpecialClass().

143  {
144  // This is like the power method to find the top eigenvalue of a matrix.
145  // We limit it to 200 iters max, just in case something unanticipated
146  // happens, but we should exit due to the "delta" thing, usually after
147  // several tens of iterations.
148  int iter, max_iter = 200;
149 
150  for (iter = 0; iter < max_iter; iter++) {
151  std::vector<double> new_occ(num_states_);
152  // We initialize new_occ to 0.1 * occ. A simpler algorithm would
153  // initialize them to zero, so it's like the pure power method. This is
154  // like the power method on (M + 0.1 I), and we do it this way to avoid a
155  // problem we encountered with certain very simple linear FSTs where the
156  // eigenvalues of the weight matrix (including negative and imaginary
157  // ones) all have the same magnitude.
158  for (int i = 0; i < num_states_; i++)
159  new_occ[i] = 0.1 * occ_[i];
160 
161  for (int i = 0; i < num_states_; i++) {
162  std::vector<std::pair<StateId, double> >::const_iterator iter,
163  end = pred_[i].end();
164  for (iter = pred_[i].begin(); iter != end; ++iter) {
165  StateId j = iter->first;
166  double p = iter->second;
167  new_occ[j] += occ_[i] * p;
168  }
169  }
170  double sumsq = 0.0;
171  for (int i = 0; i < num_states_; i++) sumsq += new_occ[i] * new_occ[i];
172  lambda_ = std::sqrt(sumsq);
173  double inv_lambda = 1.0 / lambda_;
174  for (int i = 0; i < num_states_; i++) occ_[i] = new_occ[i] * inv_lambda;
175  KALDI_VLOG(4) << "Lambda is " << lambda_;
176  if (iter % 5 == 0 && iter > 0 && TestAccuracy() <= delta) {
177  KALDI_VLOG(3) << "Weight-pushing converged after " << iter
178  << " iterations.";
179  return;
180  }
181  }
182  KALDI_WARN << "push-special: finished " << iter
183  << " iterations without converging. Output will be inaccurate.";
184  }
std::vector< double > occ_
#define KALDI_WARN
Definition: kaldi-error.h:150
std::vector< std::vector< std::pair< StateId, double > > > pred_
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
Arc::StateId StateId
Definition: push-special.cc:89

◆ ModifyFst()

void ModifyFst ( )
inline private

Definition at line 188 of file push-special.cc.

References PushSpecialClass::fst_, PushSpecialClass::initial_state_, KALDI_ISINF, KALDI_ISNAN, KALDI_WARN, kaldi::Log(), PushSpecialClass::num_states_, PushSpecialClass::occ_, and fst::Times().

Referenced by PushSpecialClass::PushSpecialClass().

188  {
189  // First get the potentials as negative-logs, like the values
190  // in the FST.
191  for (StateId s = 0; s < num_states_; s++) {
192  occ_[s] = -kaldi::Log(occ_[s]);
193  if (KALDI_ISNAN(occ_[s]) || KALDI_ISINF(occ_[s]))
194  KALDI_WARN << "NaN or inf found: " << occ_[s];
195  }
196  for (StateId s = 0; s < num_states_; s++) {
197  for (MutableArcIterator<VectorFst<StdArc> > aiter(fst_, s);
198  !aiter.Done(); aiter.Next()) {
199  Arc arc = aiter.Value();
200  StateId t = arc.nextstate;
201  arc.weight = Weight(arc.weight.Value() + occ_[t] - occ_[s]);
202  aiter.SetValue(arc);
203  }
204  fst_->SetFinal(s, Times(fst_->Final(s).Value(),
205  Weight(occ_[initial_state_] - occ_[s])));
206  }
207  }
#define KALDI_ISINF
Definition: kaldi-math.h:73
VectorFst< StdArc > * fst_
LatticeWeightTpl< FloatType > Times(const LatticeWeightTpl< FloatType > &w1, const LatticeWeightTpl< FloatType > &w2)
std::vector< double > occ_
double Log(double x)
Definition: kaldi-math.h:100
#define KALDI_WARN
Definition: kaldi-error.h:150
#define KALDI_ISNAN
Definition: kaldi-math.h:72
Arc::StateId StateId
Definition: push-special.cc:89

◆ TestAccuracy()

double TestAccuracy ( )
inline private

Definition at line 116 of file push-special.cc.

References kaldi::Exp(), PushSpecialClass::fst_, PushSpecialClass::initial_state_, KALDI_VLOG, kaldi::Log(), PushSpecialClass::num_states_, and PushSpecialClass::occ_.

Referenced by PushSpecialClass::Iterate().

116  { // returns the error (the difference
117  // between the min and max weights).
118  double min_sum = 0, max_sum = 0;
119  for (StateId s = 0; s < num_states_; s++) {
120  double sum = 0.0;
121  for (ArcIterator<VectorFst<StdArc> > aiter(*fst_, s);
122  !aiter.Done(); aiter.Next()) {
123  const Arc &arc = aiter.Value();
124  StateId t = arc.nextstate;
125  sum += kaldi::Exp(-arc.weight.Value()) * occ_[t] / occ_[s];
126  }
127  sum += kaldi::Exp(-(fst_->Final(s).Value())) * occ_[initial_state_] / occ_[s];
128  if (s == 0) {
129  min_sum = sum;
130  max_sum = sum;
131  } else {
132  min_sum = std::min(min_sum, sum);
133  max_sum = std::max(max_sum, sum);
134  }
135  }
136  KALDI_VLOG(4) << "min,max is " << min_sum << " " << max_sum;
137  return kaldi::Log(max_sum / min_sum); // In FST world we'll actually
138  // dealing with logs, so the log of the ratio is more suitable
139  // to compare with delta (makes testing the algorithm easier).
140  }
double Exp(double x)
Definition: kaldi-math.h:83
VectorFst< StdArc > * fst_
std::vector< double > occ_
double Log(double x)
Definition: kaldi-math.h:100
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
Arc::StateId StateId
Definition: push-special.cc:89

Member Data Documentation

◆ fst_

VectorFst<StdArc>* fst_
private

◆ initial_state_

StateId initial_state_
private

◆ lambda_

double lambda_
private

Definition at line 213 of file push-special.cc.

Referenced by PushSpecialClass::Iterate().

◆ num_states_

StateId num_states_
private

◆ occ_

std::vector<double> occ_
private

◆ pred_

std::vector<std::vector<std::pair<StateId, double> > > pred_
private

The documentation for this class was generated from the following file: push-special.cc