30 #include "chain/chain-supervision.h" 45 const std::string &utt_id,
53 if (!utt_splitter->
LengthsMatch(utt_id, num_input_frames, num_output_frames))
56 std::vector<ChunkTimeInfo> chunks;
61 KALDI_WARN <<
"Not producing egs for utterance " << utt_id
62 <<
" because it is too short: " 63 << num_input_frames <<
" frames.";
71 for (
size_t c = 0; c < chunks.size(); c++) {
75 nnet_discriminative_eg.
outputs.resize(1);
78 num_frames_subsampled = chunk.
num_frames / frame_subsampling_factor;
83 num_frames_subsampled,
84 (c == 0 ?
false :
true),
91 int32 first_frame = 0;
96 frame_subsampling_factor);
97 nnet_discriminative_eg.
outputs[0].Swap(&nnet_supervision);
99 nnet_discriminative_eg.
inputs.resize(ivector_feats != NULL ? 2 : 1);
109 for (
int32 t = start_frame; t < start_frame + tot_input_frames; t++) {
112 if (t2 >= num_input_frames) t2 = num_input_frames - 1;
113 int32 j = t - start_frame;
115 dest(input_frames, j);
120 nnet_discriminative_eg.
inputs[0].Swap(&input_io);
122 if (ivector_feats != NULL) {
126 start_frame + num_input_frames - 1),
127 ivector_frame_subsampled = ivector_frame / ivector_period;
128 if (ivector_frame_subsampled < 0)
129 ivector_frame_subsampled = 0;
130 if (ivector_frame_subsampled >= ivector_feats->
NumRows())
131 ivector_frame_subsampled = ivector_feats->
NumRows() - 1;
133 ivector.
Row(0).CopyFromVec(ivector_feats->
Row(ivector_frame_subsampled));
134 NnetIo ivector_io(
"ivector", 0, ivector);
135 nnet_discriminative_eg.
inputs[1].Swap(&ivector_io);
141 std::ostringstream os;
144 std::string key = os.str();
146 example_writer->
Write(key, nnet_discriminative_eg);
155 int main(
int argc,
char *argv[]) {
157 using namespace kaldi;
160 typedef kaldi::int64 int64;
163 "Get frame-by-frame examples of data for nnet3+sequence neural network\n" 164 "training. This involves breaking up utterances into pieces of sizes\n" 165 "determined by the --num-frames option.\n" 167 "Usage: nnet3-discriminative-get-egs [options] <model> <features-rspecifier> " 168 "<denominator-lattice-rspecifier> <numerator-alignment-rspecifier> <egs-wspecifier>\n" 170 "An example [where $feats expands to the actual features]:\n" 171 " nnet3-discriminative-get-egs --left-context=25 --right-context=9 --num-frames=150,100,90 \\\n" 172 " \"$feats\" \"ark,s,cs:gunzip -c lat.1.gz\" scp:ali.scp ark:degs.1.ark\n";
174 bool compress =
true;
175 int32 length_tolerance = 100, online_ivector_period = 1;
177 std::string online_ivector_rspecifier;
186 po.
Register(
"compress", &compress,
"If true, write egs in " 187 "compressed format (recommended)");
188 po.
Register(
"ivectors", &online_ivector_rspecifier,
"Alias for --online-ivectors " 189 "option, for back compatibility");
190 po.
Register(
"online-ivectors", &online_ivector_rspecifier,
"Rspecifier of ivector " 191 "features, as a matrix.");
192 po.
Register(
"online-ivector-period", &online_ivector_period,
"Number of frames " 193 "between iVectors in matrices supplied to the --online-ivectors " 195 po.
Register(
"length-tolerance", &length_tolerance,
"Tolerance for " 196 "difference in num-frames between feat and ivector matrices");
210 std::string model_wxfilename = po.
GetArg(1),
211 feature_rspecifier = po.
GetArg(2),
212 den_lat_rspecifier = po.
GetArg(3),
213 num_ali_rspecifier = po.
GetArg(4),
214 examples_wspecifier = po.
GetArg(5);
220 Input ki(model_wxfilename, &binary);
229 online_ivector_rspecifier);
233 for (; !feat_reader.
Done(); feat_reader.
Next()) {
234 std::string key = feat_reader.
Key();
236 if (!den_lat_reader.
HasKey(key)) {
237 KALDI_WARN <<
"No denominator lattice for key " << key;
239 }
else if (!ali_reader.
HasKey(key)) {
240 KALDI_WARN <<
"No numerator alignment for key " << key;
245 den_lat_reader.
Value(key),
247 KALDI_WARN <<
"Failed to convert lattice to supervision " 248 <<
"for utterance " << key;
254 if (!online_ivector_rspecifier.empty()) {
255 if (!online_ivector_reader.
HasKey(key)) {
256 KALDI_WARN <<
"No iVectors for utterance " << key;
262 online_ivector_feats = &(online_ivector_reader.
Value(key));
265 if (online_ivector_feats != NULL &&
267 online_ivector_period)) > length_tolerance
268 || online_ivector_feats->
NumRows() == 0)) {
270 <<
" and iVectors " << online_ivector_feats->
NumRows()
271 <<
"exceeds tolerance " << length_tolerance;
276 feats, online_ivector_feats, online_ivector_period,
277 supervision, key, compress,
278 &utt_splitter, &example_writer))
283 KALDI_WARN << num_err <<
" utterances had errors and could " 287 }
catch(
const std::exception &e) {
288 std::cerr << e.what() <<
'\n';
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
void GetFrameRange(int32 begin_frame, int32 frames_per_sequence, bool normalize, DiscriminativeSupervision *supervision) const
bool LengthsMatch(const std::string &utt, int32 utterance_length, int32 supervision_length, int32 length_tolerance=0) const
void Register(OptionsItf *opts)
int32 frame_subsampling_factor
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Base class which provides matrix operations not involving resizing or allocation. ...
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
int main(int argc, char *argv[])
A templated class for writing objects to an archive or script file; see The Table concept...
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
void CopyFromVec(const VectorBase< Real > &v)
Copy data from another vector (must match own size).
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
const T & Value(const std::string &key)
static bool ProcessFile(const discriminative::SplitDiscriminativeSupervisionOptions &config, const TransitionModel &tmodel, const MatrixBase< BaseFloat > &feats, const MatrixBase< BaseFloat > *ivector_feats, int32 ivector_period, const discriminative::DiscriminativeSupervision &supervision, const std::string &utt_id, bool compress, UtteranceSplitter *utt_splitter, NnetDiscriminativeExampleWriter *example_writer)
void Read(std::istream &is, bool binary)
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
void Register(OptionsItf *po)
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
int32 frames_per_sequence
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
bool HasKey(const std::string &key)
const ExampleGenerationConfig & Config() const
int NumArgs() const
Number of positional parameters (c.f. argc-1).
#define KALDI_ASSERT(cond)
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
std::vector< NnetIo > inputs
'inputs' contains the input to the network -- normally it has just one element called "input"...
bool Initialize(const std::vector< int32 > &alignment, const Lattice &lat, BaseFloat weight)
std::vector< NnetDiscriminativeSupervision > outputs
'outputs' contains the sequence output supervision.
void GetChunksForUtterance(int32 utterance_length, std::vector< ChunkTimeInfo > *chunk_info)
struct ChunkTimeInfo is used by class UtteranceSplitter to output information about how we split an u...
std::vector< BaseFloat > output_weights
Represents a non-allocating general vector which can be defined as a sub-vector of higher-level vecto...
NnetDiscriminativeExample is like NnetExample, but specialized for sequence training.
int32 RandInt(int32 min_val, int32 max_val, struct RandomState *state)
void ComputeDerived()
This function decodes 'num_frames_str' into 'num_frames', and ensures that the members of 'num_frames...