31 int32 ans = floor(expected_count);
32 expected_count -= ans;
39 int main(
int argc,
char *argv[]) {
41 using namespace kaldi;
44 typedef kaldi::int64 int64;
47 "Copy examples for nnet3 discriminative training, possibly changing the binary mode.\n" 48 "Supports multiple wspecifiers, in which case it will write the examples\n" 49 "round-robin to the outputs.\n" 51 "Usage: nnet3-discriminative-copy-egs [options] <egs-rspecifier> <egs-wspecifier1> [<egs-wspecifier2> ...]\n" 54 "nnet3-discriminative-copy-egs ark:train.degs ark,t:text.degs\n" 56 "nnet3-discriminative-copy-egs ark:train.degs ark:1.degs ark:2.degs\n";
60 int32 frame_shift = 0;
64 po.
Register(
"random", &random,
"If true, will write frames to output " 65 "archives randomly, not round-robin.");
66 po.
Register(
"keep-proportion", &keep_proportion,
"If <1.0, this program will " 67 "randomly keep this proportion of the input samples. If >1.0, it will " 68 "in expectation copy a sample this many times. It will copy it a number " 69 "of times equal to floor(keep-proportion) or ceil(keep-proportion).");
70 po.
Register(
"srand", &srand_seed,
"Seed for random number generator " 71 "(only relevant if --random=true or --keep-proportion != 1.0)");
72 po.
Register(
"frame-shift", &frame_shift,
"Allows you to shift time values " 73 "in the supervision data (excluding iVector data) - useful in " 74 "augmenting data. Note, the outputs will remain at the closest " 75 "exact multiples of the frame subsampling factor");
86 std::string examples_rspecifier = po.
GetArg(1);
90 int32 num_outputs = po.
NumArgs() - 1;
91 std::vector<NnetDiscriminativeExampleWriter*> example_writers(num_outputs);
92 for (int32
i = 0;
i < num_outputs;
i++)
95 std::vector<std::string> exclude_names;
97 exclude_names.push_back(std::string(
"ivector"));
100 int64 num_read = 0, num_written = 0;
101 for (; !example_reader.
Done(); example_reader.
Next(), num_read++) {
104 std::string key = example_reader.
Key();
105 if (frame_shift == 0) {
107 for (int32 c = 0; c <
count; c++) {
108 int32 index = (random ?
Rand() : num_written) % num_outputs;
109 example_writers[index]->Write(key, eg);
112 }
else if (count > 0) {
114 if (frame_shift != 0)
116 for (int32 c = 0; c <
count; c++) {
117 int32 index = (random ?
Rand() : num_written) % num_outputs;
118 example_writers[index]->Write(key, eg);
123 for (int32
i = 0;
i < num_outputs;
i++)
124 delete example_writers[
i];
126 <<
" neural-network training examples, wrote " << num_written;
127 return (num_written == 0 ? 1 : 0);
128 }
catch(
const std::exception &e) {
129 std::cerr << e.what() <<
'\n';
void ShiftDiscriminativeExampleTimes(int32 frame_shift, const std::vector< std::string > &exclude_names, NnetDiscriminativeExample *eg)
Shifts the time-index t of everything in the input of "eg" by adding "t_offset" to all "t" values– b...
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
bool WithProb(BaseFloat prob, struct RandomState *state)
void Register(const std::string &name, bool *ptr, const std::string &doc)
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
int32 GetCount(double expected_count)
TableWriter< KaldiObjectHolder< NnetDiscriminativeExample > > NnetDiscriminativeExampleWriter
int Rand(struct RandomState *state)
int NumArgs() const
Number of positional parameters (c.f. argc-1).
#define KALDI_ASSERT(cond)
int main(int argc, char *argv[])
int32 GetCount(double expected_count)
NnetDiscriminativeExample is like NnetExample, but specialized for sequence training.