deeptigp
Python Question

Understanding tf.extract_image_patches for extracting patches from an image

I found the method tf.extract_image_patches in the TensorFlow API, but I am not clear about what it does. Say batch_size = 1, the image is of size 225x225x3, and we want to extract patches of size 32x32: how exactly does this function behave? Specifically, the documentation says the output tensor has shape [batch, out_rows, out_cols, ksize_rows * ksize_cols * depth], but it does not say what out_rows and out_cols are.

Ideally, given an input image tensor of shape 1x225x225x3 (where 1 is the batch size), I want to get an output of shape Kx32x32x3, where K is the total number of patches and 32x32x3 is the shape of each patch. Is there something in TensorFlow that already does this?

Answer

Here is how the method works:

  • ksizes sets the size of each patch, i.e. how many pixels tall and wide it is, given as [1, ksize_rows, ksize_cols, 1].
  • strides is the distance between the start of one patch and the start of the next consecutive patch in the original image, given as [1, stride_rows, stride_cols, 1].
  • rates is the dilation rate, given as [1, rate_rows, rate_cols, 1]: consecutive pixels of a patch are taken rates pixels apart in the original image, so the patch skips rates - 1 pixels between samples. (The example below illustrates this.)
  • padding is either "VALID", which means every patch must lie fully inside the image, or "SAME", which means patches may extend past the edge of the image, with the missing pixels filled in with zeros.
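Since the documentation leaves out_rows and out_cols unspecified, here is a small sketch of how they can be computed. It assumes tf.extract_image_patches uses the same output-size rules as TensorFlow's convolutions, with a dilated patch spanning ksize + (ksize - 1) * (rate - 1) input pixels; this agrees with all four sample outputs in this answer. The function name output_size is hypothetical.

```python
import math


def output_size(in_size, ksize, stride, rate=1, padding="VALID"):
    """Number of patch positions along one spatial dimension (out_rows or out_cols).

    Assumption: tf.extract_image_patches follows the convolution size rules.
    """
    # A dilated patch covers this many input pixels from its first to last sample.
    effective_ksize = ksize + (ksize - 1) * (rate - 1)
    if padding == "VALID":
        # Every patch must fit entirely inside the image.
        return max(0, math.ceil((in_size - effective_ksize + 1) / stride))
    if padding == "SAME":
        # Zeros are padded on so that only the stride determines the count.
        return math.ceil(in_size / stride)
    raise ValueError("padding must be 'VALID' or 'SAME'")


# The four calls from the sample code, on a 10x10 input:
print(output_size(10, 3, 5))                  # 2
print(output_size(10, 3, 5, rate=2))          # 2
print(output_size(10, 4, 7))                  # 1
print(output_size(10, 4, 7, padding="SAME"))  # 2

# The question's case: a 225x225 image cut into 32x32 patches with stride 32.
rows = output_size(225, 32, 32)
print(rows, rows * rows)  # 7 49
```

With non-overlapping 32x32 patches on a 225x225 image this gives 7 x 7 = 49 patches; the last row and column of pixels are dropped under "VALID".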

Here is some sample code with output to help demonstrate how it works:

import tensorflow as tf  # TensorFlow 1.x API; in 2.x the op is tf.image.extract_patches

n = 10
# images is a 1 x 10 x 10 x 1 array that contains the numbers 1 through 100 in order
images = [[[[x * n + y + 1] for y in range(n)] for x in range(n)]]

# We generate four outputs as follows:
# 1. 3x3 patches with stride length 5
# 2. Same as above, but the rate is increased to 2
# 3. 4x4 patches with stride length 7; only one patch should be generated
# 4. Same as above, but with padding set to 'SAME'
with tf.Session():  # .eval() below runs in this default session
    print(tf.extract_image_patches(images, ksizes=[1, 3, 3, 1], strides=[1, 5, 5, 1],
                                   rates=[1, 1, 1, 1], padding='VALID').eval(), '\n\n')
    print(tf.extract_image_patches(images, ksizes=[1, 3, 3, 1], strides=[1, 5, 5, 1],
                                   rates=[1, 2, 2, 1], padding='VALID').eval(), '\n\n')
    print(tf.extract_image_patches(images, ksizes=[1, 4, 4, 1], strides=[1, 7, 7, 1],
                                   rates=[1, 1, 1, 1], padding='VALID').eval(), '\n\n')
    print(tf.extract_image_patches(images, ksizes=[1, 4, 4, 1], strides=[1, 7, 7, 1],
                                   rates=[1, 1, 1, 1], padding='SAME').eval())

Output:

[[[[ 1  2  3 11 12 13 21 22 23]
   [ 6  7  8 16 17 18 26 27 28]]

  [[51 52 53 61 62 63 71 72 73]
   [56 57 58 66 67 68 76 77 78]]]]


[[[[  1   3   5  21  23  25  41  43  45]
   [  6   8  10  26  28  30  46  48  50]]

  [[ 51  53  55  71  73  75  91  93  95]
   [ 56  58  60  76  78  80  96  98 100]]]]


[[[[ 1  2  3  4 11 12 13 14 21 22 23 24 31 32 33 34]]]]


[[[[  1   2   3   4  11  12  13  14  21  22  23  24  31  32  33  34]
   [  8   9  10   0  18  19  20   0  28  29  30   0  38  39  40   0]]

  [[ 71  72  73  74  81  82  83  84  91  92  93  94   0   0   0   0]
   [ 78  79  80   0  88  89  90   0  98  99 100   0   0   0   0   0]]]]
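To make the sampling rule concrete, here is a minimal NumPy sketch that reproduces the outputs above without TensorFlow. The function extract_patches is hypothetical, not part of any library. It assumes a single ksize, stride and rate shared by both spatial dimensions, and that "SAME" padding puts the smaller half of the zero padding on the top/left side, as TensorFlow's convolutions do.

```python
import numpy as np


def extract_patches(images, ksize, stride, rate=1, padding="VALID"):
    """NumPy sketch of the sampling done by tf.extract_image_patches.

    images: array-like of shape [batch, rows, cols, depth].
    ksize, stride, rate: ints shared by both spatial dimensions (a simplification).
    Returns an array of shape [batch, out_rows, out_cols, ksize * ksize * depth].
    """
    images = np.asarray(images)
    batch, rows, cols, depth = images.shape
    span = ksize + (ksize - 1) * (rate - 1)  # input pixels covered by one dilated patch

    if padding == "VALID":
        out_rows = (rows - span) // stride + 1
        out_cols = (cols - span) // stride + 1
        pad_rows = pad_cols = 0
    elif padding == "SAME":
        out_rows = -(-rows // stride)  # ceil(rows / stride)
        out_cols = -(-cols // stride)
        pad_rows = max((out_rows - 1) * stride + span - rows, 0)
        pad_cols = max((out_cols - 1) * stride + span - cols, 0)
    else:
        raise ValueError("padding must be 'VALID' or 'SAME'")

    # Zero-pad, putting the smaller half of the padding on the top/left side.
    top, left = pad_rows // 2, pad_cols // 2
    padded = np.pad(
        images,
        ((0, 0), (top, pad_rows - top), (left, pad_cols - left), (0, 0)),
        mode="constant",
    )

    out = np.empty((batch, out_rows, out_cols, ksize * ksize * depth), dtype=padded.dtype)
    for i in range(out_rows):
        for j in range(out_cols):
            r0, c0 = i * stride, j * stride
            # Every `rate`-th pixel of the dilated window starting at (r0, c0).
            patch = padded[:, r0:r0 + span:rate, c0:c0 + span:rate, :]
            # Flatten each patch in (row, col, depth) order, like TF's last dimension.
            out[:, i, j, :] = patch.reshape(batch, -1)
    return out


n = 10
images = [[[[x * n + y + 1] for y in range(n)] for x in range(n)]]
print(extract_patches(images, 3, 5)[0, 0, 0].tolist())
# [1, 2, 3, 11, 12, 13, 21, 22, 23]
print(extract_patches(images, 3, 5, rate=2)[0, 1, 1].tolist())
# [56, 58, 60, 76, 78, 80, 96, 98, 100]
print(extract_patches(images, 4, 7, padding="SAME")[0, 0, 1].tolist())
# [8, 9, 10, 0, 18, 19, 20, 0, 28, 29, 30, 0, 38, 39, 40, 0]
```

The explicit loops favour readability over speed; they make it easy to see which input pixel lands in each slot of the last dimension.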

So, for example, the patches in our first result cover the pixels marked with * below:

 *  *  *  4  5  *  *  *  9 10 
 *  *  * 14 15  *  *  * 19 20 
 *  *  * 24 25  *  *  * 29 30 
31 32 33 34 35 36 37 38 39 40 
41 42 43 44 45 46 47 48 49 50 
 *  *  * 54 55  *  *  * 59 60 
 *  *  * 64 65  *  *  * 69 70 
 *  *  * 74 75  *  *  * 79 80 
81 82 83 84 85 86 87 88 89 90 
91 92 93 94 95 96 97 98 99 100 

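As you can see, there are 2 rows and 2 columns of patches: these are out_rows and out_cols. Each of the 2 x 2 positions holds one 3x3 patch flattened to 3 * 3 * 1 = 9 values, so the output shape is [1, 2, 2, 9].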
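For the Kx32x32x3 layout asked about in the question, the [batch, out_rows, out_cols, ksize_rows * ksize_cols * depth] output only needs a reshape, provided each patch is flattened in (row, col, depth) order; the depth-1 outputs above show the row-major part of that order, and depth being innermost is an assumption here. In TensorFlow this would presumably be tf.reshape(patches, [-1, 32, 32, 3]) applied to the result of tf.extract_image_patches with ksizes=[1, 32, 32, 1], strides=[1, 32, 32, 1], rates=[1, 1, 1, 1] and padding='VALID'. The sketch below checks the idea in NumPy on a hypothetical stand-in image.

```python
import numpy as np

# Hypothetical stand-in for the question's input: a 1 x 225 x 225 x 3 tensor.
image = np.arange(225 * 225 * 3).reshape(1, 225, 225, 3)
k, depth = 32, 3
n_out = (225 - k) // k + 1  # VALID with stride 32 -> 7 patches per dimension

# Mimic the extract_image_patches layout [batch, out_rows, out_cols, 32 * 32 * 3],
# flattening each patch in (row, col, depth) order (assumed to match TensorFlow).
patches = np.empty((1, n_out, n_out, k * k * depth), dtype=image.dtype)
for i in range(n_out):
    for j in range(n_out):
        patches[0, i, j] = image[0, i * k:(i + 1) * k, j * k:(j + 1) * k, :].reshape(-1)

# Collapsing batch, out_rows and out_cols into one axis gives K x 32 x 32 x 3.
patch_stack = patches.reshape(-1, k, k, depth)
print(patch_stack.shape)  # (49, 32, 32, 3)

# Patch number i * 7 + j is the 32x32 block whose top-left corner is (32 * i, 32 * j).
assert np.array_equal(patch_stack[8], image[0, 32:64, 32:64, :])
```

Because the reshape is done in C order, K enumerates patches row by row, left to right, matching the out_rows / out_cols grid.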
