Luca - 3 months ago
Python Question

Extract hyper-cubical blocks from a numpy array with unknown number of dimensions

I have a bit of Python code that is currently hard-wired for two-dimensional arrays, as follows:

import numpy as np
data = np.random.rand(5, 5)
width = 3

for y in range(0, data.shape[1] - width + 1):
    for x in range(0, data.shape[0] - width + 1):
        block = data[x:x+width, y:y+width]
        # Do something with this block
        pass


Now, this is hard-coded for a 2-dimensional array, and I would like to extend it to 3D and 4D arrays. I could, of course, write more functions for other dimensions, but I was wondering whether there is a Python/NumPy trick to generate these sub-blocks without having to replicate this code for multidimensional data.

Answer

Here is my whack at this problem. The idea behind the code below is to find the "starting indices" for each slice of data. For example, for 4x4x4 sub-arrays of a 5x5x5 array, the starting indices would be (0,0,0), (0,0,1), (0,1,0), (0,1,1), (1,0,0), (1,0,1), (1,1,0), (1,1,1), and the slices along each dimension would be of length 4.
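As a quick check of that list, the starting corners are just the Cartesian product of one range of valid start positions per axis. A minimal sketch for the 5x5x5 array with width 4 (the names `shape`, `width`, and `starts` are only for illustration):

```python
from itertools import product

shape = (5, 5, 5)   # shape of the full array
width = 4           # edge length of each cubic sub-array

# One range of valid start positions per axis: 0 .. n - width
ranges = [range(n - width + 1) for n in shape]

# The Cartesian product of those ranges gives every starting corner
starts = list(product(*ranges))
print(len(starts))  # 8, i.e. 2 start positions on each of 3 axes
```

Because `product` works on any number of ranges, the same two lines cover 2D, 3D, or 4D shapes without change.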

To get the sub-arrays, you then iterate over the corresponding tuples of slice objects and use each tuple to index the array.
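This works because Python hands an expression like `arr[1:4, 0:3]` to the array as the tuple `(slice(1, 4), slice(0, 3))`, so building that tuple yourself gives the same block. A small 2D sketch:

```python
import numpy as np

arr = np.arange(25).reshape(5, 5)

# Equivalent to writing arr[1:4, 0:3] by hand
index = (slice(1, 4), slice(0, 3))
block = arr[index]

print(block.shape)                           # (3, 3)
print(np.array_equal(block, arr[1:4, 0:3]))  # True
```

Since this is basic slicing, `block` is a view into `arr`, not a copy.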

import numpy as np
from itertools import product

def iterslice(data_shape, width):
    # check for invalid width
    assert all(sh >= width for sh in data_shape), \
        'all axis lengths must be at least equal to width'

    # gather all allowed starting indices for the data shape
    start_indices = [range(sh - width + 1) for sh in data_shape]

    # create tuples of all allowed starting indices
    start_coords = product(*start_indices)

    # yield tuples of slice objects with the same number of dimensions
    # as data_shape, to be used to index the array
    for start_coord in start_coords:
        yield tuple(slice(coord, coord + width) for coord in start_coord)

# create 5x5x5 array
arr = np.arange(0,5**3).reshape(5,5,5)

# create the data slice tuple iterator for 3x3x3 sub-arrays
data_slices = iterslice(arr.shape, 3)

# the sub-arrays are a list of 3x3x3 arrays, in this case
sub_arrays = [arr[ds] for ds in data_slices]
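If NumPy 1.20 or later is available, `numpy.lib.stride_tricks.sliding_window_view` can produce the same blocks without building slices by hand. This is an alternative technique, not the approach above; the helper name `iter_blocks` is hypothetical. A minimal sketch:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20


def iter_blocks(data, width):
    """Yield every cubic block of edge `width` from an n-dimensional array.

    Hypothetical helper; the blocks are read-only views, not copies.
    """
    # Window of the same length along every axis
    windows = sliding_window_view(data, (width,) * data.ndim)
    # windows.shape == (n0 - width + 1, ..., width, ..., width):
    # the first data.ndim axes select the starting corner of a block
    for corner in np.ndindex(*windows.shape[:data.ndim]):
        yield windows[corner]


arr = np.arange(5 ** 3).reshape(5, 5, 5)
blocks = list(iter_blocks(arr, 3))
print(len(blocks), blocks[0].shape)  # 27 (3, 3, 3)
```

`np.ndindex` walks the corners in the same row-major order as `itertools.product`, so the blocks come out in the same order as with `iterslice`. A window wider than any axis makes `sliding_window_view` raise a `ValueError`, which plays the role of the width check.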