nnet3-egs-augment-image.cc
Go to the documentation of this file.
1 // nnet3bin/nnet3-egs-augment-image.cc
2 
3 // Copyright 2017 Johns Hopkins University (author: Daniel Povey)
4 // 2017 Hossein Hadian
5 // 2017 Yiwen Shao
6 
7 // See ../../COPYING for clarification regarding multiple authors
8 //
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 //
13 // http://www.apache.org/licenses/LICENSE-2.0
14 //
15 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
17 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
18 // MERCHANTABILITY OR NON-INFRINGEMENT.
19 // See the Apache 2 License for the specific language governing permissions and
20 // limitations under the License.
21 
22 #include "base/kaldi-common.h"
23 #include "util/common-utils.h"
24 #include "hmm/transition-model.h"
25 #include "nnet3/nnet-example.h"
27 
28 namespace kaldi {
29 namespace nnet3 {
30 
32 
40  std::string fill_mode_string;
41 
43  num_channels(1),
44  horizontal_flip_prob(0.0),
45  horizontal_shift(0.0),
46  vertical_shift(0.0),
47  rotation_degree(0.0),
48  rotation_prob(0.0),
49  fill_mode_string("nearest") { }
50 
51 
// Registers every image-augmentation option with the command-line parser
// 'po'.  Each call binds an option name to the member variable that receives
// the parsed value, together with the help text shown in the usage message.
52  void Register(ParseOptions *po) {
// Adjacent string literals are concatenated, so each piece must end with
// the separating space itself.
53  po->Register("num-channels", &num_channels, "Number of colors in the image. "
54  "It is important to specify this (helps interpret the image "
55  "correctly).");
56  po->Register("horizontal-flip-prob", &horizontal_flip_prob,
57  "Probability of doing horizontal flip");
// How pixels outside the image are filled is chosen by --fill-mode, so the
// shift help texts refer to that option rather than to one fixed behavior.
58  po->Register("horizontal-shift", &horizontal_shift,
59  "Maximum allowed horizontal shift as proportion of image "
60  "width. Padding is controlled by --fill-mode.");
61  po->Register("vertical-shift", &vertical_shift,
62  "Maximum allowed vertical shift as proportion of image "
63  "height. Padding is controlled by --fill-mode.");
64  po->Register("rotation-degree", &rotation_degree,
65  "Maximum allowed degree to rotate the image");
66  po->Register("rotation-prob", &rotation_prob,
67  "Probability of doing rotation");
68  po->Register("fill-mode", &fill_mode_string, "Mode for dealing with "
69  "points outside the image boundary when applying transformation. "
70  "Choices = {nearest, reflect}");
71  }
72 
// Validates the option values held by this config.  Every constraint is
// enforced with KALDI_ASSERT; the function itself returns nothing.
73  void Check() const {
// An image needs at least one channel.
74  KALDI_ASSERT(num_channels >= 1);
// Probabilities must lie in [0, 1].
75  KALDI_ASSERT(horizontal_flip_prob >= 0 &&
76  horizontal_flip_prob <= 1);
// Shifts are given as a proportion of the image width/height, so they are
// restricted to [0, 1].
77  KALDI_ASSERT(horizontal_shift >= 0 && horizontal_shift <= 1);
78  KALDI_ASSERT(vertical_shift >= 0 && vertical_shift <= 1);
// The maximum rotation is expressed in degrees and limited to [0, 180].
79  KALDI_ASSERT(rotation_degree >=0 && rotation_degree <= 180);
80  KALDI_ASSERT(rotation_prob >=0 && rotation_prob <= 1);
// Only the two fill modes named in the --fill-mode help text are accepted.
81  KALDI_ASSERT(fill_mode_string == "nearest" || fill_mode_string == "reflect");
82  }
83 
85  FillMode fill_mode;
86  if (fill_mode_string == "reflect") {
87  fill_mode = kReflect;
88  } else {
89  if (fill_mode_string != "nearest") {
90  KALDI_ERR << "Choices for --fill-mode are 'nearest' or 'reflect', got: "
92  } else {
93  fill_mode = kNearest;
94  }
95  }
96  return fill_mode;
97  }
98 };
99 
112  MatrixBase<BaseFloat> *image,
113  FillMode fill_mode) {
114  int32 num_rows = image->NumRows(),
115  num_cols = image->NumCols(),
116  height = num_cols / num_channels,
117  width = num_rows;
118  KALDI_ASSERT(num_cols % num_channels == 0);
119  Matrix<BaseFloat> original_image(*image);
120  for (int32 r = 0; r < width; r++) {
121  for (int32 c = 0; c < height; c++) {
122  // (r_old, c_old) is the coordinate of the pixel in the original image
123  // while (r, c) is the coordinate in the new (transformed) image.
124  BaseFloat r_old = transform(0, 0) * r +
125  transform(0, 1) * c + transform(0, 2);
126  BaseFloat c_old = transform(1, 0) * r +
127  transform(1, 1) * c + transform(1, 2);
128  // We are going to do bilinear interpolation between 4 closest points
129  // to the point (r_old, c_old) of the original image. We have:
130  // r1 <= r_old <= r2
131  // c1 <= c_old <= c2
132  int32 r1 = static_cast<int32>(floor(r_old));
133  int32 c1 = static_cast<int32>(floor(c_old));
134  int32 r2 = r1 + 1;
135  int32 c2 = c1 + 1;
136 
137  // These weights determine how much each of the 4 points contributes
138  // to the final interpolated value:
139  BaseFloat weight_11 = (r2 - r_old) * (c2 - c_old),
140  weight_12 = (r2 - r_old) * (c_old - c1),
141  weight_21 = (r_old - r1) * (c2 - c_old),
142  weight_22 = (r_old - r1) * (c_old - c1);
143  // Handle edge conditions:
144  if (fill_mode == kNearest) {
145  if (r1 < 0) {
146  r1 = 0;
147  if (r2 < 0) r2 = 0;
148  }
149  if (r2 >= width) {
150  r2 = width - 1;
151  if (r1 >= width) r1 = width - 1;
152  }
153  if (c1 < 0) {
154  c1 = 0;
155  if (c2 < 0) c2 = 0;
156  }
157  if (c2 >= height) {
158  c2 = height - 1;
159  if (c1 >= height) c1 = height - 1;
160  }
161  } else {
162  KALDI_ASSERT(fill_mode == kReflect);
163  if (r1 < 0) {
164  r1 = - r1;
165  if (r2 < 0) r2 = - r2;
166  }
167  if (r2 >= width) {
168  r2 = 2 * width - 2 - r2;
169  if (r1 >= width) r1 = 2 * width - 2 - r1;
170  }
171  if (c1 < 0) {
172  c1 = - c1;
173  if (c2 < 0) c2 = -c2;
174  }
175  if (c2 >= height) {
176  c2 = 2 * height - 2 - c2;
177  if (c1 >= height) c1 = 2 * height - 2 - c1;
178  }
179  }
180  for (int32 ch = 0; ch < num_channels; ch++) {
181  // find the values at the 4 points
182  BaseFloat p11 = original_image(r1, num_channels * c1 + ch),
183  p12 = original_image(r1, num_channels * c2 + ch),
184  p21 = original_image(r2, num_channels * c1 + ch),
185  p22 = original_image(r2, num_channels * c2 + ch);
186  (*image)(r, num_channels * c + ch) = weight_11 * p11 + weight_12 * p12 +
187  weight_21 * p21 + weight_22 * p22;
188  }
189  }
190  }
191 }
192 
206  MatrixBase<BaseFloat> *image) {
207  config.Check();
208  FillMode fill_mode = config.GetFillMode();
209  int32 image_width = image->NumRows(),
210  num_channels = config.num_channels,
211  image_height = image->NumCols() / num_channels;
212  if (image->NumCols() % num_channels != 0) {
213  KALDI_ERR << "Number of columns in image must divide the number "
214  "of channels";
215  }
216  // We do an affine transform which
217  // handles flipping, translation, rotation, magnification, and shear.
218  Matrix<BaseFloat> transform_mat(3, 3, kUndefined);
219  transform_mat.SetUnit();
220 
221  Matrix<BaseFloat> shift_mat(3, 3, kUndefined);
222  shift_mat.SetUnit();
223  // translation (shift) mat:
224  // [ 1 0 x_shift
225  // 0 1 y_shift
226  // 0 0 1 ]
227  BaseFloat horizontal_shift = (2.0 * RandUniform() - 1.0) *
228  config.horizontal_shift * image_width;
229  BaseFloat vertical_shift = (2.0 * RandUniform() - 1.0) *
230  config.vertical_shift * image_height;
231  shift_mat(0, 2) = round(horizontal_shift);
232  shift_mat(1, 2) = round(vertical_shift);
233  // since we will center the image before applying the transform,
234  // horizontal flipping is simply achieved by setting [0, 0] to -1:
235  if (WithProb(config.horizontal_flip_prob))
236  shift_mat(0, 0) = -1.0;
237 
238  Matrix<BaseFloat> rotation_mat(3, 3, kUndefined);
239  rotation_mat.SetUnit();
240  // rotation mat:
241  // [ cos(theta) -sin(theta) 0
242  // sin(theta) cos(theta) 0
243  // 0 0 1 ]
244  if (RandUniform() <= config.rotation_prob) {
245  BaseFloat theta = (2 * config.rotation_degree * RandUniform() -
246  config.rotation_degree) / 180.0 * M_PI;
247  rotation_mat(0, 0) = cos(theta);
248  rotation_mat(0, 1) = -sin(theta);
249  rotation_mat(1, 0) = sin(theta);
250  rotation_mat(1, 1) = cos(theta);
251  }
252 
253  Matrix<BaseFloat> shear_mat(3, 3, kUndefined);
254  shear_mat.SetUnit();
255  // shear mat:
256  // [ 1 -sin(shear) 0
257  // 0 cos(shear) 0
258  // 0 0 1 ]
259 
260  Matrix<BaseFloat> zoom_mat(3, 3, kUndefined);
261  zoom_mat.SetUnit();
262  // zoom mat:
263  // [ x_zoom 0 0
264  // 0 y_zoom 0
265  // 0 0 1 ]
266 
267  // transform_mat = rotation_mat * shift_mat * shear_mat * zoom_mat:
268  transform_mat.AddMatMat(1.0, shift_mat, kNoTrans,
269  shear_mat, kNoTrans, 0.0);
270  transform_mat.AddMatMatMat(1.0, rotation_mat, kNoTrans,
271  transform_mat, kNoTrans,
272  zoom_mat, kNoTrans, 0.0);
273  if (transform_mat.IsUnit()) // nothing to do
274  return;
275 
276  // we should now change the origin of transform to the center of
277  // the image (necessary for flipping, zoom, shear, and rotation)
278  // we do this by using two translations: one before the main transform
279  // and one after.
280  Matrix<BaseFloat> set_origin_mat(3, 3, kUndefined);
281  set_origin_mat.SetUnit();
282  set_origin_mat(0, 2) = image_width / 2.0 - 0.5;
283  set_origin_mat(1, 2) = image_height / 2.0 - 0.5;
284  Matrix<BaseFloat> reset_origin_mat(3, 3, kUndefined);
285  reset_origin_mat.SetUnit();
286  reset_origin_mat(0, 2) = -image_width / 2.0 + 0.5;
287  reset_origin_mat(1, 2) = -image_height / 2.0 + 0.5;
288 
289  // transform_mat = set_origin_mat * transform_mat * reset_origin_mat
290  transform_mat.AddMatMatMat(1.0, set_origin_mat, kNoTrans,
291  transform_mat, kNoTrans,
292  reset_origin_mat, kNoTrans, 0.0);
293  ApplyAffineTransform(transform_mat, config.num_channels, image, fill_mode);
294 }
295 
296 
303  const ImageAugmentationConfig &config,
304  NnetExample *eg) {
305  int32 io_size = eg->io.size();
306  bool found_input = false;
307  for (int32 i = 0; i < io_size; i++) {
308  NnetIo &io = eg->io[i];
309  if (io.name == "input") {
310  found_input = true;
311  Matrix<BaseFloat> image;
312  io.features.GetMatrix(&image);
313  // note: 'GetMatrix' may uncompress if it was compressed.
314  // We won't recompress, but this won't matter because this
315  // program is intended to be used as part of a pipe, we
316  // likely won't be dumping the perturbed data to disk.
317  PerturbImage(config, &image);
318 
319  // modify the 'io' object.
320  io.features = image;
321  }
322  }
323  if (!found_input)
324  KALDI_ERR << "Nnet example to perturb had no NnetIo object named 'input'";
325 }
326 
327 
328 } // namespace nnet3
329 } // namespace kaldi
330 
// Entry point.  Reads nnet3 examples from the rspecifier given as the first
// positional argument, perturbs each one with PerturbImageInNnetExample()
// according to the command-line options, and writes the result under the same
// key to the wspecifier given as the second positional argument.
// Returns 0 if at least one example was processed, 1 if none were, and -1 if
// a std::exception was caught.
331 int main(int argc, char *argv[]) {
332  try {
333  using namespace kaldi;
334  using namespace kaldi::nnet3;
335  typedef kaldi::int32 int32;
336  typedef kaldi::int64 int64;
337 
338  const char *usage =
339  "Copy examples (single frames or fixed-size groups of frames) for neural\n"
340  "network training, doing image augmentation inline (copies after possibly\n"
341  "modifying of each image, randomly chosen according to configuration\n"
342  "parameters).\n"
343  "E.g.:\n"
344  " nnet3-egs-augment-image --horizontal-flip-prob=0.5 --horizontal-shift=0.1\\\n"
345  " --vertical-shift=0.1 --srand=103 --num-channels=3 --fill-mode=nearest ark:- ark:-\n"
346  "\n"
347  "Requires that each eg contain a NnetIo object 'input', with successive\n"
348  "'t' values representing different x offsets , and the feature dimension\n"
349  "representing the y offset and the channel (color), with the channel\n"
350  "varying the fastest.\n"
351  "See also: nnet3-copy-egs\n";
352 
353 
354  int32 srand_seed = 0;
355 
// Holds the augmentation options; it is registered with the option parser
// below and passed to PerturbImageInNnetExample() for every example.
356  ImageAugmentationConfig config;
357 
358  ParseOptions po(usage);
359  po.Register("srand", &srand_seed, "Seed for the random number generator");
360 
361  config.Register(&po);
362 
363  po.Read(argc, argv);
364 
// Seed the C library random number generator with the --srand value.
365  srand(srand_seed);
366 
// Exactly two positional arguments (rspecifier, wspecifier) are read below;
// reject any other count instead of silently ignoring extra arguments.
367  if (po.NumArgs() != 2) {
368  po.PrintUsage();
369  exit(1);
370  }
371 
372 
373  std::string examples_rspecifier = po.GetArg(1),
374  examples_wspecifier = po.GetArg(2);
375 
376  SequentialNnetExampleReader example_reader(examples_rspecifier);
377  NnetExampleWriter example_writer(examples_wspecifier);
378 
379 
380  int64 num_done = 0;
381  for (; !example_reader.Done(); example_reader.Next(), num_done++) {
382  std::string key = example_reader.Key();
// Work on a copy of the example, since it is modified in place.
383  NnetExample eg(example_reader.Value());
384  PerturbImageInNnetExample(config, &eg);
385  example_writer.Write(key, eg);
386  }
387  KALDI_LOG << "Perturbed " << num_done << " neural-network training images.";
// A run that processed no examples is reported as a failure.
388  return (num_done == 0 ? 1 : 0);
389  } catch(const std::exception &e) {
390  std::cerr << e.what() << '\n';
391  return -1;
392  }
393 }
NnetExample is the input data and corresponding label (or labels) for one or more frames of input...
Definition: nnet-example.h:111
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
float RandUniform(struct RandomState *state=NULL)
Returns a random number strictly between 0 and 1.
Definition: kaldi-math.h:151
void GetMatrix(Matrix< BaseFloat > *mat) const
Outputs the contents as a matrix.
#define M_PI
Definition: kaldi-math.h:44
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
int main(int argc, char *argv[])
Base class which provides matrix operations not involving resizing or allocation. ...
Definition: kaldi-matrix.h:49
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
bool WithProb(BaseFloat prob, struct RandomState *state)
Definition: kaldi-math.cc:72
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
GeneralMatrix features
The features or labels.
Definition: nnet-example.h:46
void SetUnit()
Sets to zero, except ones along diagonal [for non-square matrices too].
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
void PerturbImageInNnetExample(const ImageAugmentationConfig &config, NnetExample *eg)
This function does image perturbation as directed by 'config'. The example 'eg' is expected to contain...
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
void AddMatMat(const Real alpha, const MatrixBase< Real > &A, MatrixTransposeType transA, const MatrixBase< Real > &B, MatrixTransposeType transB, const Real beta)
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
void PerturbImage(const ImageAugmentationConfig &config, MatrixBase< BaseFloat > *image)
This function randomly modifies (perturbs) the image by applying different geometric transformations ...
void AddMatMatMat(const Real alpha, const MatrixBase< Real > &A, MatrixTransposeType transA, const MatrixBase< Real > &B, MatrixTransposeType transB, const MatrixBase< Real > &C, MatrixTransposeType transC, const Real beta)
this <– beta*this + alpha*A*B*C.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void ApplyAffineTransform(MatrixBase< BaseFloat > &transform, int32 num_channels, MatrixBase< BaseFloat > *image, FillMode fill_mode)
This function applies a geometric transformation 'transform' to the image.
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
std::string name
the name of the input in the neural net; in simple setups it will just be "input".
Definition: nnet-example.h:36
std::vector< NnetIo > io
"io" contains the input and output.
Definition: nnet-example.h:116
bool IsUnit(Real cutoff=1.0e-05) const
Returns true if the matrix is all zeros, except for ones on diagonal.
#define KALDI_LOG
Definition: kaldi-error.h:153