53 if (token !=
"</NnetChainSup>") {
77 frames_per_sequence =
supervision.frames_per_sequence;
79 for (
int32 i = 0;
i < frames_per_sequence;
i++) {
80 for (
int32 j = 0;
j < num_sequences;
j++,k++) {
81 int32 n =
j, t =
i * frame_skip + first_frame, x = 0;
108 const std::string &
name,
114 supervision(supervision),
115 deriv_weights(deriv_weights) {
117 indexes.resize(supervision.num_sequences *
118 supervision.frames_per_sequence);
119 int32 k = 0, num_sequences = supervision.num_sequences,
120 frames_per_sequence = supervision.frames_per_sequence;
121 for (
int32 i = 0;
i < frames_per_sequence;
i++) {
122 for (
int32 j = 0;
j < num_sequences;
j++,k++) {
124 indexes[k].t =
i * frame_skip + first_frame;
137 int32 size = inputs.size();
139 KALDI_ASSERT(size > 0 &&
"Attempting to write NnetChainExample with no inputs");
140 if (!binary) os <<
'\n';
142 inputs[
i].Write(os, binary);
143 if (!binary) os <<
'\n';
146 size = outputs.size();
148 KALDI_ASSERT(size > 0 &&
"Attempting to write NnetChainExample with no outputs");
149 if (!binary) os <<
'\n';
151 outputs[
i].Write(os, binary);
152 if (!binary) os <<
'\n';
162 if (size < 1 || size > 1000000)
166 inputs[
i].
Read(is, binary);
169 if (size < 1 || size > 1000000)
171 outputs.resize(size);
173 outputs[
i].
Read(is, binary);
178 inputs.swap(other->
inputs);
183 std::vector<NnetIo>::iterator iter = inputs.begin(), end = inputs.end();
186 for (; iter != end; ++iter) iter->features.Compress();
190 inputs(other.inputs), outputs(other.outputs) { }
196 const std::vector<const NnetChainSupervision*> &
inputs,
198 int32 num_inputs = inputs.size(),
200 for (
int32 n = 0;
n < num_inputs;
n++) {
202 num_indexes += inputs[
n]->indexes.size();
204 output->
name = inputs[0]->name;
205 std::vector<const chain::Supervision*> input_supervision;
206 input_supervision.reserve(inputs.size());
207 for (
int32 n = 0;
n < num_inputs;
n++)
208 input_supervision.push_back(&(inputs[
n]->supervision));
209 chain::Supervision output_supervision;
211 &output_supervision);
215 output->
indexes.reserve(num_indexes);
216 for (
int32 n = 0; n < num_inputs; n++) {
217 const std::vector<Index> &src_indexes = inputs[
n]->indexes;
220 src_indexes.begin(), src_indexes.end());
221 std::vector<Index>::iterator iter = output->
indexes.begin() + cur_size,
225 for (; iter != end; ++iter) {
226 KALDI_ASSERT(iter->n == 0 &&
"Merging already-merged chain egs");
237 if (inputs[0]->deriv_weights.Dim() != 0) {
238 int32 frames_per_sequence = inputs[0]->deriv_weights.Dim();
241 frames_per_sequence * num_inputs);
242 for (
int32 n = 0; n < num_inputs; n++) {
247 for (
int32 t = 0; t < frames_per_sequence; t++) {
248 output->
deriv_weights(t * num_inputs + n) = src_deriv_weights(t);
257 std::vector<NnetChainExample> *input,
259 int32 num_examples = input->size();
263 std::vector<NnetExample> eg_inputs(num_examples);
264 for (
int32 i = 0;
i < num_examples;
i++)
265 eg_inputs[
i].io.swap((*input)[
i].inputs);
269 for (
int32 i = 0; i < num_examples; i++)
270 eg_inputs[i].io.swap((*input)[i].inputs);
277 int32 num_output_names = (*input)[0].outputs.size();
278 output->
outputs.resize(num_output_names);
279 for (
int32 i = 0; i < num_output_names; i++) {
280 std::vector<const NnetChainSupervision*> to_merge(num_examples);
281 for (
int32 j = 0;
j < num_examples;
j++) {
283 to_merge[
j] = &((*input)[
j].outputs[
i]);
292 bool need_model_derivative,
293 bool store_component_stats,
294 bool use_xent_regularization,
295 bool use_xent_derivative,
303 for (
size_t i = 0;
i < eg.
inputs.size();
i++) {
305 const std::string &name = io.
name;
307 if (node_index == -1 ||
309 KALDI_ERR <<
"Nnet example has input named '" << name
310 <<
"', but no such input node is in the network.";
318 for (
size_t i = 0;
i < eg.
outputs.size();
i++) {
321 const std::string &name = sup.
name;
323 if (node_index == -1 &&
325 KALDI_ERR <<
"Nnet example has output named '" << name
326 <<
"', but no such output node is in the network.";
331 io_spec.
has_deriv = need_model_derivative;
333 if (use_xent_regularization) {
334 size_t cur_size = request->
outputs.size();
335 request->
outputs.resize(cur_size + 1);
337 &io_spec_xent = request->
outputs[cur_size];
341 io_spec_xent = io_spec;
342 io_spec_xent.
name = name +
"-xent";
343 io_spec_xent.has_deriv = use_xent_derivative;
347 if (request->
inputs.empty())
348 KALDI_ERR <<
"No inputs in computation request.";
350 KALDI_ERR <<
"No outputs in computation request.";
354 const std::vector<std::string> &exclude_names,
356 std::vector<NnetIo>::iterator input_iter = eg->
inputs.begin(),
357 input_end = eg->
inputs.end();
358 for (; input_iter != input_end; ++input_iter) {
359 bool must_exclude =
false;
360 std::vector<std::string>::const_iterator exclude_iter = exclude_names.begin(),
361 exclude_end = exclude_names.end();
362 for (; exclude_iter != exclude_end; ++exclude_iter)
363 if (input_iter->name == *exclude_iter)
366 std::vector<Index>::iterator indexes_iter = input_iter->indexes.begin(),
367 indexes_end = input_iter->indexes.end();
368 for (; indexes_iter != indexes_end; ++indexes_iter)
369 indexes_iter->t += frame_shift;
375 std::vector<NnetChainSupervision>::iterator
376 sup_iter = eg->
outputs.begin(),
378 for (; sup_iter != sup_end; ++sup_iter) {
379 std::vector<Index> &indexes = sup_iter->indexes;
380 KALDI_ASSERT(indexes.size() >= 2 && indexes[0].n == indexes[1].n &&
381 indexes[0].x == indexes[1].x);
382 int32 frame_subsampling_factor = indexes[1].t - indexes[0].t;
387 int32 supervision_frame_shift =
388 frame_subsampling_factor *
389 std::floor(0.5 + (frame_shift * 1.0 / frame_subsampling_factor));
390 if (supervision_frame_shift == 0)
392 std::vector<Index>::iterator indexes_iter = indexes.begin(),
393 indexes_end = indexes.end();
394 for (; indexes_iter != indexes_end; ++indexes_iter)
395 indexes_iter->t += supervision_frame_shift;
404 size_t size = eg.inputs.size(), ans = size * 35099;
405 for (
size_t i = 0;
i < size;
i++)
406 ans = ans * 19157 + io_hasher(eg.inputs[
i]);
407 for (
size_t i = 0;
i < eg.outputs.size();
i++) {
412 string_hasher(sup.
name) + indexes_hasher(sup.
indexes);
424 size_t size = a.
inputs.size();
425 for (
size_t i = 0;
i < size;
i++)
429 for (
size_t i = 0;
i < size;
i++)
439 for (
size_t i = 0;
i < a.
inputs.size();
i++) {
444 for (
size_t i = 0;
i < a.
outputs.size();
i++) {
455 finished_(false), num_egs_written_(0),
456 config_(config), writer_(writer) { }
466 std::vector<NnetChainExample*> &vec =
eg_to_egs_[eg];
469 num_available = vec.size();
470 bool input_ended =
false;
473 if (minibatch_size != 0) {
476 std::vector<NnetChainExample*> vec_copy(vec);
481 std::vector<NnetChainExample> egs_to_merge(minibatch_size);
482 for (
int32 i = 0;
i < minibatch_size;
i++) {
483 egs_to_merge[
i].Swap(vec_copy[
i]);
491 std::vector<NnetChainExample> *egs) {
495 size_t structure_hash = eg_hasher((*egs)[0]);
496 int32 minibatch_size = egs->size();
500 std::ostringstream key;
511 std::vector<std::vector<NnetChainExample*> > all_egs;
515 for (; iter != end; ++iter)
516 all_egs.push_back(iter->second);
519 for (
size_t i = 0;
i < all_egs.size();
i++) {
520 int32 minibatch_size;
521 std::vector<NnetChainExample*> &vec = all_egs[
i];
524 bool input_ended =
true;
525 while (!vec.empty() &&
527 input_ended)) != 0) {
531 std::vector<NnetChainExample> egs_to_merge(minibatch_size);
532 for (
int32 i = 0;
i < minibatch_size;
i++) {
533 egs_to_merge[
i].Swap(vec[
i]);
536 vec.erase(vec.begin(), vec.begin() + minibatch_size);
542 size_t structure_hash = eg_hasher(*(vec[0]));
543 int32 num_discarded = vec.size();
545 for (
int32 i = 0;
i < num_discarded;
i++)
NnetExample is the input data and corresponding label (or labels) for one or more frames of input...
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void WriteIndexVector(std::ostream &os, bool binary, const std::vector< Index > &vec)
Vector< BaseFloat > deriv_weights
This is a vector of per-frame weights, required to be between 0 and 1, that is applied to the derivat...
void DiscardedExamples(int32 example_size, size_t structure_hash, int32 num_discarded)
Users call this function to inform this class that after processing all the data, for examples of ori...
bool store_component_stats
you should set store_component_stats to true if you need the average-activation and average-derivative...
bool need_model_derivative
if need_model_derivative is true, then we'll be doing either model training or model-derivative compu...
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
chain::Supervision supervision
The supervision object, containing the FST.
std::vector< NnetIo > inputs
'inputs' contains the input to the network– normally it has just one element called "input"...
void MergeChainExamples(bool compress, std::vector< NnetChainExample > *input, NnetChainExample *output)
This function merges a list of NnetChainExample objects into a single one– intended to be used when ...
void ShiftChainExampleTimes(int32 frame_shift, const std::vector< std::string > &exclude_names, NnetChainExample *eg)
Shifts the time-index t of everything in the input of "eg" by adding "frame_shift" to all "t" values– b...
bool IsInputNode(int32 node) const
Returns true if this is an input node, meaning that it is of type kInput.
static void MergeSupervision(const std::vector< const NnetChainSupervision *> &inputs, NnetChainSupervision *output)
int32 MinibatchSize(int32 size_of_eg, int32 num_available_egs, bool input_ended) const
This function tells you what minibatch size should be used for this eg.
A templated class for writing objects to an archive or script file; see The Table concept...
void ReadToken(std::istream &is, bool binary, std::string *str)
ReadToken gets the next token and puts it in str (exception on failure).
std::vector< IoSpecification > inputs
std::vector< Index > indexes
"indexes" is a vector the same length as features.NumRows(), explaining the meaning of each row of th...
void Write(std::ostream &os, bool binary) const
void Swap(NnetChainSupervision *other)
A hashing function object for strings.
std::string name
the name of the output in the neural net; in simple setups it will just be "output".
void Write(const std::string &key, const T &value) const
struct Index is intended to represent the various indexes by which we number the rows of the matrices...
int32 GetNnetChainExampleSize(const NnetChainExample &a)
bool operator()(const NnetChainExample &a, const NnetChainExample &b) const
void PrintStats() const
Calling this will cause a log message with information about the examples to be printed.
This hashing object hashes just the structural aspects of the NnetExample without looking at the valu...
bool IsOutputNode(int32 node) const
Returns true if this is an output node, meaning that it is of type kDescriptor and is not directly fo...
static void ExpectToken(const std::string &token, const std::string &what_we_are_parsing, const std::string **next_token)
std::vector< NnetChainSupervision > outputs
'outputs' contains the chain output supervision.
ChainExampleMerger(const ExampleMergingConfig &config, NnetChainExampleWriter *writer)
void Read(std::istream &is, bool binary)
NnetChainExample is like NnetExample, but specialized for lattice-free (chain) training.
void AcceptExample(NnetChainExample *a)
size_t operator()(const NnetChainExample &eg) const noexcept
void Swap(NnetChainExample *other)
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
MatrixIndexT Dim() const
Returns the dimension of the vector.
void ReadIndexVector(std::istream &is, bool binary, std::vector< Index > *vec)
void Read(std::istream &is, bool binary)
void WroteExample(int32 example_size, size_t structure_hash, int32 minibatch_size)
Users call this function to inform this class that one minibatch has been written aggregating 'miniba...
const ExampleMergingConfig & config_
NnetChainExampleWriter * writer_
bool operator==(const NnetChainSupervision &other) const
std::vector< Index > indexes
void Write(std::ostream &os, bool binary) const
ExampleMergingStats stats_
A class representing a vector.
#define KALDI_ASSERT(cond)
std::vector< IoSpecification > outputs
This comparison object compares just the structural aspects of the NnetIo object (name, indexes, feature dimension) without looking at the value of features.
std::vector< Index > indexes
The indexes that the output corresponds to.
This hashing object hashes just the structural aspects of the NnetIo object (name, indexes, feature dimension) without looking at the value of features.
void ReadVectorAsChar(std::istream &is, bool binary, Vector< BaseFloat > *vec)
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
std::string name
the name of the input in the neural net; in simple setups it will just be "input".
void WriteMinibatch(std::vector< NnetChainExample > *egs)
int32 GetNodeIndex(const std::string &node_name) const
returns index associated with this node name, or -1 if no such index.
Provides a vector abstraction class.
void GetChainComputationRequest(const Nnet &nnet, const NnetChainExample &eg, bool need_model_derivative, bool store_component_stats, bool use_xent_regularization, bool use_xent_derivative, ComputationRequest *request)
This function takes a NnetChainExample and produces a ComputationRequest.
std::vector< NnetIo > io
"io" contains the input and output.
int32 RandInt(int32 min_val, int32 max_val, struct RandomState *state)
void MergeExamples(const std::vector< NnetExample > &src, bool compress, NnetExample *merged_eg)
Merge a set of input examples into a single example (typically the size of "src" will be the minibatc...