25 namespace discriminative {
30 weight(other.weight), num_sequences(other.num_sequences),
31 frames_per_sequence(other.frames_per_sequence),
32 num_ali(other.num_ali), den_lat(other.den_lat) { }
52 WriteToken(os, binary,
"<DiscriminativeSupervision>");
69 KALDI_ERR <<
"Error writing denominator lattice to stream";
72 WriteToken(os, binary,
"</DiscriminativeSupervision>");
76 ExpectToken(is, binary,
"<DiscriminativeSupervision>");
92 if (!
ReadLattice(is, binary, &lat) || lat == NULL) {
95 KALDI_ERR <<
"Error reading Lattice from stream";
102 ExpectToken(is, binary,
"</DiscriminativeSupervision>");
108 if (num_ali.size() == 0)
return false;
109 if (den_lat.NumStates() == 0)
return false;
130 std::vector<int32> state_times;
140 config_(config), tmodel_(tmodel), supervision_(supervision) {
142 KALDI_WARN <<
"Splitting already-reattached sequence (only expected in " 156 KALDI_ASSERT(start_state == 0 &&
"Expecting start-state to be 0");
170 const std::vector<int32> &state_times,
Lattice *lat)
const {
174 int32 num_frames = state_times.back();
175 StateId num_states = lat->NumStates();
177 std::vector<std::map<int32, int32> > pdf_to_tid(num_frames);
178 for (StateId s = 0; s < num_states; s++) {
179 int32 t = state_times[s];
180 for (fst::MutableArcIterator<Lattice> aiter(lat, s);
181 !aiter.Done(); aiter.Next()) {
183 Arc arc = aiter.Value();
184 KALDI_ASSERT(arc.ilabel != 0 && arc.ilabel == arc.olabel);
186 if (pdf_to_tid[t].
count(pdf) != 0) {
187 arc.ilabel = arc.olabel = pdf_to_tid[t][pdf];
190 pdf_to_tid[t][pdf] = arc.ilabel;
199 state_times.size() == beta.size());
208 int32 end_frame = begin_frame + num_frames;
212 begin_frame + num_frames <=
217 begin_frame, end_frame, normalize,
220 out_supervision->
num_ali.clear();
223 std::back_inserter(out_supervision->
num_ali));
229 out_supervision->
Check();
234 int32 begin_frame,
int32 end_frame,
bool normalize,
238 const std::vector<int32> &state_times = scores.
state_times;
242 if (!in_lat.Properties(fst::kTopSorted,
true))
243 KALDI_ERR <<
"Input lattice must be topologically sorted.";
245 std::vector<int32>::const_iterator begin_iter =
246 std::lower_bound(state_times.begin(), state_times.end(), begin_frame),
247 end_iter = std::lower_bound(begin_iter,
248 state_times.end(), end_frame);
251 (begin_iter == state_times.begin() ||
252 begin_iter[-1] < begin_frame));
256 (end_iter < state_times.end() || *end_iter == end_frame));
257 StateId begin_state = begin_iter - state_times.begin(),
258 end_state = end_iter - state_times.begin();
261 out_lat->DeleteStates();
262 out_lat->ReserveStates(end_state - begin_state + 2);
265 StateId start_state = out_lat->AddState();
266 out_lat->SetStart(start_state);
268 for (StateId
i = begin_state;
i < end_state;
i++)
272 StateId final_state = out_lat->AddState();
275 for (StateId state = begin_state; state < end_state; state++) {
276 StateId output_state = state - begin_state + 1;
277 if (state_times[state] == begin_frame) {
294 out_lat->AddArc(start_state,
299 for (fst::ArcIterator<Lattice> aiter(in_lat, state);
300 !aiter.Done(); aiter.Next()) {
302 StateId nextstate = arc.nextstate;
303 if (nextstate >= end_state) {
310 weight.
SetValue1(arc.weight.Value1() - scores.
beta[nextstate]);
317 out_lat->AddArc(output_state,
318 LatticeArc(arc.ilabel, arc.olabel, weight, final_state));
320 StateId output_nextstate = nextstate - begin_state + 1;
321 out_lat->AddArc(output_state,
322 LatticeArc(arc.ilabel, arc.olabel, arc.weight, output_nextstate));
329 fst::Project(out_lat, fst::PROJECT_INPUT);
330 fst::RmEpsilon(out_lat);
338 fst::Determinize(*out_lat, &tmp_lat);
342 fst::Reverse(*out_lat, &tmp_lat);
343 fst::Determinize(tmp_lat, out_lat);
344 fst::Reverse(*out_lat, &tmp_lat);
345 fst::Determinize(tmp_lat, out_lat);
346 fst::RmEpsilon(out_lat);
350 fst::TopSort(out_lat);
351 std::vector<int32> state_times_tmp;
353 end_frame - begin_frame);
374 int32 num_states = lat->NumStates();
375 std::vector<std::pair<int32,int32> > state_time_indexes(num_states);
376 for (
int32 s = 0; s < num_states; s++) {
377 state_time_indexes[s] = std::make_pair(scores->
state_times[s], s);
382 std::sort(state_time_indexes.begin(), state_time_indexes.end());
384 std::vector<int32> state_order(num_states);
385 for (
int32 s = 0; s < num_states; s++) {
386 state_order[state_time_indexes[s].second] = s;
389 fst::StateSort(lat, state_order);
405 int32 num_inputs = input.size();
406 if (num_inputs == 1) {
407 *output_supervision = *(input[0]);
410 *output_supervision = *(input[num_inputs-1]);
411 for (
int32 i = num_inputs - 2;
i >= 0;
i--) {
421 output_supervision->
num_ali.insert(
422 output_supervision->
num_ali.begin(),
427 KALDI_ERR <<
"Mismatch weight or frames_per_sequence between inputs";
431 fst::TopSort(&(out_sup.
den_lat));
fst::StdArc::StateId StateId
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void GetFrameRange(int32 begin_frame, int32 frames_per_sequence, bool normalize, DiscriminativeSupervision *supervision) const
fst::ArcTpl< LatticeWeight > LatticeArc
bool ReadLattice(std::istream &is, bool binary, Lattice **lat)
int32 LatticeStateTimes(const Lattice &lat, vector< int32 > *times)
This function iterates over the states of a topologically sorted lattice and counts the time instance...
void ComputeLatticeScores(const Lattice &lat, LatticeInfo *scores) const
double ComputeLatticeAlphasAndBetas(const LatticeType &lat, bool viterbi, vector< double > *alpha, vector< double > *beta)
static const LatticeWeightTpl One()
void CreateRangeLattice(const Lattice &in_lat, const LatticeInfo &scores, int32 begin_frame, int32 end_frame, bool normalize, Lattice *out_lat) const
const SplitDiscriminativeSupervisionOptions & config_
void Write(std::ostream &os, bool binary) const
std::vector< double > alpha
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
void CollapseTransitionIds(const std::vector< int32 > &state_times, Lattice *lat) const
void swap(basic_filebuf< CharT, Traits > &x, basic_filebuf< CharT, Traits > &y)
const TransitionModel & tmodel_
void MergeSupervision(const std::vector< const DiscriminativeSupervision *> &input, DiscriminativeSupervision *output_supervision)
This function appends a list of supervision objects to create what will usually be a single such obje...
std::vector< double > beta
int32 TransitionIdToPdf(int32 trans_id) const
bool operator==(const DiscriminativeSupervision &other) const
LatticeInfo den_lat_scores_
void Swap(DiscriminativeSupervision *other)
bool collapse_transition_ids
std::vector< std::vector< double > > AcousticLatticeScale(double acwt)
void PrepareLattice(Lattice *lat, LatticeInfo *scores) const
void ReadIntegerVector(std::istream &is, bool binary, std::vector< T > *v)
Function for reading STL vector of integer types.
void ScaleLattice(const std::vector< std::vector< ScaleFloat > > &scale, MutableFst< ArcTpl< Weight > > *fst)
Scales the pairs of weights in LatticeWeight or CompactLatticeWeight by viewing the pair (a...
void ExpectToken(std::istream &is, bool binary, const char *token)
ExpectToken tries to read in the given token, and throws an exception on failure. ...
std::vector< int32 > num_ali
const DiscriminativeSupervision & supervision_
fst::VectorFst< LatticeArc > Lattice
void Read(std::istream &is, bool binary)
int32 frames_per_sequence
DiscriminativeSupervision()
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
std::vector< int32 > state_times
bool WriteLattice(std::ostream &os, bool binary, const Lattice &t)
bool IsSorted(const std::vector< T > &vec)
Returns true if the vector is sorted.
#define KALDI_ASSERT(cond)
bool Initialize(const std::vector< int32 > &alignment, const Lattice &lat, BaseFloat weight)
void WriteIntegerVector(std::ostream &os, bool binary, const std::vector< T > &v)
Function for writing STL vectors of integer types.
fst::VectorFst< LatticeArc > Lattice
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
DiscriminativeSupervisionSplitter(const SplitDiscriminativeSupervisionOptions &config, const TransitionModel &tmodel, const DiscriminativeSupervision &supervision)