nnet3-xvector-get-egs.cc
Go to the documentation of this file.
1 // nnet3bin/nnet3-xvector-get-egs.cc
2 
3 // Copyright 2016-2017 Johns Hopkins University (author: Daniel Povey)
4 // 2016-2017 Johns Hopkins University (author: Daniel Garcia-Romero)
5 // 2016-2017 David Snyder
6 
7 // See ../../COPYING for clarification regarding multiple authors
8 //
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 //
13 // http://www.apache.org/licenses/LICENSE-2.0
14 //
15 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
17 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
18 // MERCHANTABLITY OR NON-INFRINGEMENT.
19 // See the Apache 2 License for the specific language governing permissions and
20 // limitations under the License.
21 
22 #include <sstream>
23 #include "util/common-utils.h"
24 #include "nnet3/nnet-example.h"
25 
26 namespace kaldi {
27 namespace nnet3 {
28 
29 // A struct for holding information about the position and
30 // duration of each chunk.
31 struct ChunkInfo {
32  std::string name;
37 };
38 
39 // Process the range input file and store it as a map from utterance
40 // name to vector of ChunkInfo structs.
41 static void ProcessRangeFile(const std::string &range_rxfilename,
42  unordered_map<std::string, std::vector<ChunkInfo *> > *utt_to_chunks) {
43  Input range_input(range_rxfilename);
44  if (!range_rxfilename.empty()) {
45  std::string line;
46  while (std::getline(range_input.Stream(), line)) {
47  ChunkInfo *chunk_info = new ChunkInfo();
48  std::vector<std::string> fields;
49  SplitStringToVector(line, " \t\n\r", true, &fields);
50  if (fields.size() != 6)
51  KALDI_ERR << "Expected 6 fields in line of range file, got "
52  << fields.size() << " instead.";
53 
54  std::string utt = fields[0],
55  start_frame_str = fields[3],
56  num_frames_str = fields[4],
57  label_str = fields[5];
58 
59  if (!ConvertStringToInteger(fields[1], &(chunk_info->output_archive_id))
60  || !ConvertStringToInteger(start_frame_str, &(chunk_info->start_frame))
61  || !ConvertStringToInteger(num_frames_str, &(chunk_info->num_frames))
62  || !ConvertStringToInteger(label_str, &(chunk_info->label)))
63  KALDI_ERR << "Expected integer for output archive in range file.";
64 
65  chunk_info->name = utt + "-" + start_frame_str + "-" + num_frames_str
66  + "-" + label_str;
67  unordered_map<std::string, std::vector<ChunkInfo*> >::iterator
68  got = utt_to_chunks->find(utt);
69 
70  if (got == utt_to_chunks->end()) {
71  std::vector<ChunkInfo* > chunk_infos;
72  chunk_infos.push_back(chunk_info);
73  utt_to_chunks->insert(std::pair<std::string,
74  std::vector<ChunkInfo* > > (utt, chunk_infos));
75  } else {
76  got->second.push_back(chunk_info);
77  }
78  }
79  }
80 }
81 
82 static void WriteExamples(const MatrixBase<BaseFloat> &feats,
83  const std::vector<ChunkInfo *> &chunks, const std::string &utt,
84  bool compress, int32 num_pdfs, int32 *num_egs_written,
85  std::vector<NnetExampleWriter *> *example_writers) {
86  for (std::vector<ChunkInfo *>::const_iterator it = chunks.begin();
87  it != chunks.end(); ++it) {
88  ChunkInfo *chunk = *it;
89  NnetExample eg;
90  int32 num_rows = feats.NumRows(),
91  feat_dim = feats.NumCols();
92  if (num_rows < chunk->num_frames) {
93  KALDI_WARN << "Unable to create examples for utterance " << utt
94  << ". Requested chunk size of "
95  << chunk->num_frames
96  << " but utterance has only " << num_rows << " frames.";
97  } else {
98  // The requested chunk positions are approximate. It's possible
99  // that they slightly exceed the number of frames in the utterance.
100  // If that occurs, we can shift the chunks location back slightly.
101  int32 shift = std::min(0, num_rows - chunk->start_frame
102  - chunk->num_frames);
103  SubMatrix<BaseFloat> chunk_mat(feats, chunk->start_frame + shift,
104  chunk->num_frames, 0, feat_dim);
105  NnetIo nnet_input = NnetIo("input", 0, chunk_mat);
106  for (std::vector<Index>::iterator indx_it = nnet_input.indexes.begin();
107  indx_it != nnet_input.indexes.end(); ++indx_it)
108  indx_it->n = 0;
109 
111  std::vector<std::pair<int32, BaseFloat> > post;
112  post.push_back(std::pair<int32, BaseFloat>(chunk->label, 1.0));
113  label.push_back(post);
114  NnetExample eg;
115  eg.io.push_back(nnet_input);
116  eg.io.push_back(NnetIo("output", num_pdfs, 0, label));
117  if (compress)
118  eg.Compress();
119 
120  if (chunk->output_archive_id >= example_writers->size())
121  KALDI_ERR << "Requested output index exceeds number of specified "
122  << "output files.";
123  (*example_writers)[chunk->output_archive_id]->Write(
124  chunk->name, eg);
125  (*num_egs_written) += 1;
126  }
127  }
128 }
129 
130 } // namespace nnet3
131 } // namespace kaldi
132 
133 int main(int argc, char *argv[]) {
134  try {
135  using namespace kaldi;
136  using namespace kaldi::nnet3;
137  typedef kaldi::int32 int32;
138 
139  const char *usage =
140  "Get examples for training an nnet3 neural network for the xvector\n"
141  "system. Each output example contains a chunk of features from some\n"
142  "utterance along with a speaker label. The location and length of\n"
143  "the feature chunks are specified in the 'ranges' file. Each line\n"
144  "is interpreted as follows:\n"
145  " <source-utterance> <relative-output-archive-index> "
146  "<absolute-archive-index> <start-frame-index> <num-frames> "
147  "<speaker-label>\n"
148  "where <relative-output-archive-index> is interpreted as a zero-based\n"
149  "index into the wspecifiers provided on the command line (<egs-0-out>\n"
150  "and so on), and <absolute-archive-index> is ignored by this program.\n"
151  "For example:\n"
152  " utt1 3 13 65 300 3\n"
153  " utt1 0 10 50 400 3\n"
154  " utt2 ...\n"
155  "\n"
156  "Usage: nnet3-xvector-get-egs [options] <ranges-filename> "
157  "<features-rspecifier> <egs-0-out> <egs-1-out> ... <egs-N-1-out>\n"
158  "\n"
159  "For example:\n"
160  "nnet3-xvector-get-egs ranges.1 \"$feats\" ark:egs_temp.1.ark"
161  " ark:egs_temp.2.ark ark:egs_temp.3.ark\n";
162 
163  bool compress = true;
164  int32 num_pdfs = -1;
165 
166  ParseOptions po(usage);
167  po.Register("compress", &compress, "If true, write egs in "
168  "compressed format.");
169  po.Register("num-pdfs", &num_pdfs, "Number of speakers in the training "
170  "list.");
171 
172  po.Read(argc, argv);
173 
174  if (po.NumArgs() < 3) {
175  po.PrintUsage();
176  exit(1);
177  }
178 
179  std::string range_rspecifier = po.GetArg(1),
180  feature_rspecifier = po.GetArg(2);
181  std::vector<NnetExampleWriter *> example_writers;
182 
183  for (int32 i = 3; i <= po.NumArgs(); i++)
184  example_writers.push_back(new NnetExampleWriter(po.GetArg(i)));
185 
186  unordered_map<std::string, std::vector<ChunkInfo *> > utt_to_chunks;
187  ProcessRangeFile(range_rspecifier, &utt_to_chunks);
188  SequentialBaseFloatMatrixReader feat_reader(feature_rspecifier);
189 
190  int32 num_done = 0,
191  num_err = 0,
192  num_egs_written = 0;
193 
194  for (; !feat_reader.Done(); feat_reader.Next()) {
195  std::string key = feat_reader.Key();
196  const Matrix<BaseFloat> &feats = feat_reader.Value();
197  unordered_map<std::string, std::vector<ChunkInfo*> >::iterator
198  got = utt_to_chunks.find(key);
199  if (got == utt_to_chunks.end()) {
200  KALDI_WARN << "Could not create examples from utterance "
201  << key << " because it has no entry in the ranges "
202  << "input file.";
203  num_err++;
204  } else {
205  std::vector<ChunkInfo *> chunks = got->second;
206  WriteExamples(feats, chunks, key, compress, num_pdfs,
207  &num_egs_written, &example_writers);
208  num_done++;
209  }
210  }
211 
212  // Free memory
213  for (unordered_map<std::string, std::vector<ChunkInfo*> >::iterator
214  map_it = utt_to_chunks.begin();
215  map_it != utt_to_chunks.end(); ++map_it) {
216  DeletePointers(&map_it->second);
217  }
218  DeletePointers(&example_writers);
219 
220  KALDI_LOG << "Finished generating examples, "
221  << "successfully processed " << num_done
222  << " feature files, wrote " << num_egs_written << " examples; "
223  << num_err << " files had errors.";
224  return (num_egs_written == 0 || num_err > num_done ? 1 : 0);
225  } catch(const std::exception &e) {
226  std::cerr << e.what() << '\n';
227  return -1;
228  }
229 }
NnetExample is the input data and corresponding label (or labels) for one or more frames of input...
Definition: nnet-example.h:111
int main(int argc, char *argv[])
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
bool ConvertStringToInteger(const std::string &str, Int *out)
Converts a string into an integer via strtoll and returns false if there was any kind of problem (i...
Definition: text-utils.h:118
void DeletePointers(std::vector< A *> *v)
Deletes any non-NULL pointers in the vector v, and sets the corresponding entries of v to NULL...
Definition: stl-utils.h:184
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
Base class which provides matrix operations not involving resizing or allocation. ...
Definition: kaldi-matrix.h:49
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
std::vector< Index > indexes
"indexes" is a vector the same length as features.NumRows(), explaining the meaning of each row of th...
Definition: nnet-example.h:42
void Register(const std::string &name, bool *ptr, const std::string &doc)
std::istream & Stream()
Definition: kaldi-io.cc:826
static void WriteExamples(const MatrixBase< BaseFloat > &feats, const std::vector< ChunkInfo *> &chunks, const std::string &utt, bool compress, int32 num_pdfs, int32 *num_egs_written, std::vector< NnetExampleWriter *> *example_writers)
std::vector< std::vector< std::pair< int32, BaseFloat > > > Posterior
Posterior is a typedef for storing acoustic-state (actually, transition-id) posteriors over an uttera...
Definition: posterior.h:42
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void SplitStringToVector(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< std::string > *out)
Split a string using any of the single character delimiters.
Definition: text-utils.cc:63
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
void Compress()
Compresses any (input) features that are not sparse.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
#define KALDI_WARN
Definition: kaldi-error.h:150
int NumArgs() const
Number of positional parameters (c.f. argc-1).
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
std::vector< NnetIo > io
"io" contains the input and output.
Definition: nnet-example.h:116
#define KALDI_LOG
Definition: kaldi-error.h:153
Sub-matrix representation.
Definition: kaldi-matrix.h:988
static void ProcessRangeFile(const std::string &range_rxfilename, unordered_map< std::string, std::vector< ChunkInfo *> > *utt_to_chunks)