37 const std::string &utt_id,
40 int32 length_tolerance,
44 if (!utt_splitter->
LengthsMatch(utt_id, num_input_frames,
45 static_cast<int32>(pdf_post.size()),
49 std::vector<ChunkTimeInfo> chunks;
54 KALDI_WARN <<
"Not producing egs for utterance " << utt_id
55 <<
" because it is too short: " 56 << num_input_frames <<
" frames.";
62 int32 frame_subsampling_factor =
65 for (
size_t c = 0; c < chunks.size(); c++) {
86 if (ivector_feats != NULL) {
90 start_frame + num_input_frames - 1),
91 ivector_frame_subsampled = ivector_frame / ivector_period;
92 if (ivector_frame_subsampled < 0)
93 ivector_frame_subsampled = 0;
94 if (ivector_frame_subsampled >= ivector_feats->
NumRows())
95 ivector_frame_subsampled = ivector_feats->
NumRows() - 1;
97 ivector.
Row(0).CopyFromVec(ivector_feats->
Row(ivector_frame_subsampled));
98 eg.
io.push_back(
NnetIo(
"ivector", 0, ivector));
103 int32 start_frame_subsampled = chunk.
first_frame / frame_subsampling_factor,
104 num_frames_subsampled = chunk.
num_frames / frame_subsampling_factor;
113 for (
int32 i = 0;
i < num_frames_subsampled;
i++) {
114 int32 t =
i + start_frame_subsampled;
115 if (t < pdf_post.size())
116 labels[
i] = pdf_post[t];
117 for (std::vector<std::pair<int32, BaseFloat> >::iterator
118 iter = labels[
i].begin(); iter != labels[
i].end(); ++iter)
122 eg.
io.push_back(
NnetIo(
"output", num_pdfs, 0, labels, frame_subsampling_factor));
127 std::ostringstream os;
130 std::string key = os.str();
132 example_writer->
Write(key, eg);
140 int main(
int argc,
char *argv[]) {
142 using namespace kaldi;
145 typedef kaldi::int64 int64;
148 "Get frame-by-frame examples of data for nnet3 neural network training.\n" 149 "Essentially this is a format change from features and posteriors\n" 150 "into a special frame-by-frame format. This program handles the\n" 151 "common case where you have some input features, possibly some\n" 152 "iVectors, and one set of labels. If people in future want to\n" 153 "do different things they may have to extend this program or create\n" 154 "different versions of it for different tasks (the egs format is quite\n" 157 "Usage: nnet3-get-egs [options] <features-rspecifier> " 158 "<pdf-post-rspecifier> <egs-out>\n" 160 "An example [where $feats expands to the actual features]:\n" 161 "nnet3-get-egs --num-pdfs=2658 --left-context=12 --right-context=9 --num-frames=8 \"$feats\"\\\n" 162 "\"ark:gunzip -c exp/nnet/ali.1.gz | ali-to-pdf exp/nnet/1.nnet ark:- ark:- | ali-to-post ark:- ark:- |\" \\\n" 164 "See also: nnet3-chain-get-egs, nnet3-get-egs-simple\n";
167 bool compress =
true;
168 int32 num_pdfs = -1, length_tolerance = 100,
169 targets_length_tolerance = 2,
170 online_ivector_period = 1;
175 std::string online_ivector_rspecifier;
179 po.
Register(
"compress", &compress,
"If true, write egs with input features " 180 "in compressed format (recommended). This is " 181 "only relevant if the features being read are un-compressed; " 182 "if already compressed, we keep we same compressed format when " 184 po.
Register(
"num-pdfs", &num_pdfs,
"Number of pdfs in the acoustic " 186 po.
Register(
"ivectors", &online_ivector_rspecifier,
"Alias for " 187 "--online-ivectors option, for back compatibility");
188 po.
Register(
"online-ivectors", &online_ivector_rspecifier,
"Rspecifier of " 189 "ivector features, as a matrix.");
190 po.
Register(
"online-ivector-period", &online_ivector_period,
"Number of " 191 "frames between iVectors in matrices supplied to the " 192 "--online-ivectors option");
193 po.
Register(
"length-tolerance", &length_tolerance,
"Tolerance for " 194 "difference in num-frames between feat and ivector matrices");
195 po.
Register(
"targets-length-tolerance", &targets_length_tolerance,
197 "difference in num-frames (after subsampling) between " 198 "feature matrix and posterior");
209 KALDI_ERR <<
"--num-pdfs options is required.";
214 std::string feature_rspecifier = po.
GetArg(1),
215 pdf_post_rspecifier = po.
GetArg(2),
216 examples_wspecifier = po.
GetArg(3);
226 online_ivector_rspecifier);
230 for (; !feat_reader.
Done(); feat_reader.
Next()) {
231 std::string key = feat_reader.
Key();
233 if (!pdf_post_reader.
HasKey(key)) {
234 KALDI_WARN <<
"No pdf-level posterior for key " << key;
239 if (!online_ivector_rspecifier.empty()) {
240 if (!online_ivector_reader.
HasKey(key)) {
241 KALDI_WARN <<
"No iVectors for utterance " << key;
247 online_ivector_feats = &(online_ivector_reader.
Value(key));
251 if (online_ivector_feats != NULL &&
253 online_ivector_period)) > length_tolerance
254 || online_ivector_feats->
NumRows() == 0)) {
256 <<
" and iVectors " << online_ivector_feats->
NumRows()
257 <<
"exceeds tolerance " << length_tolerance;
262 if (!
ProcessFile(feats, online_ivector_feats, online_ivector_period,
263 pdf_post, key, compress, num_pdfs,
264 targets_length_tolerance,
265 &utt_splitter, &example_writer))
270 KALDI_WARN << num_err <<
" utterances had errors and could " 274 }
catch(
const std::exception &e) {
275 std::cerr << e.what() <<
'\n';
NnetExample is the input data and corresponding label (or labels) for one or more frames of input...
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
This class is a wrapper that enables you to store a matrix in one of three forms: either as a Matrix<...
bool LengthsMatch(const std::string &utt, int32 utterance_length, int32 supervision_length, int32 length_tolerance=0) const
int32 frame_subsampling_factor
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Base class which provides matrix operations not involving resizing or allocation. ...
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
A templated class for writing objects to an archive or script file; see The Table concept...
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
const T & Value(const std::string &key)
static bool ProcessFile(const discriminative::SplitDiscriminativeSupervisionOptions &config, const TransitionModel &tmodel, const MatrixBase< BaseFloat > &feats, const MatrixBase< BaseFloat > *ivector_feats, int32 ivector_period, const discriminative::DiscriminativeSupervision &supervision, const std::string &utt_id, bool compress, UtteranceSplitter *utt_splitter, NnetDiscriminativeExampleWriter *example_writer)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
void Register(OptionsItf *po)
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
void ExtractRowRangeWithPadding(const GeneralMatrix &in, int32 row_offset, int32 num_rows, GeneralMatrix *out)
This function extracts a row-range of a GeneralMatrix and writes as a GeneralMatrix containing the sa...
void Compress()
Compresses any (input) features that are not sparse.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
bool HasKey(const std::string &key)
const ExampleGenerationConfig & Config() const
int NumArgs() const
Number of positional parameters (c.f. argc-1).
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
MatrixIndexT NumRows() const
void GetChunksForUtterance(int32 utterance_length, std::vector< ChunkTimeInfo > *chunk_info)
struct ChunkTimeInfo is used by class UtteranceSplitter to output information about how we split an u...
std::vector< NnetIo > io
"io" contains the input and output.
int main(int argc, char *argv[])
std::vector< BaseFloat > output_weights
int32 RandInt(int32 min_val, int32 max_val, struct RandomState *state)
void ComputeDerived()
This function decodes 'num_frames_str' into 'num_frames', and ensures that the members of 'num_frames...