Karthik - 11 months ago

Get intersecting rows across two 2D numpy arrays

I want to get the intersecting (common) rows across two 2D NumPy arrays. For example, if the following arrays are passed as inputs:

array([[1, 4],
       [2, 5],
       [3, 6]])

array([[1, 4],
       [3, 6],
       [7, 8]])

the output should be:

array([[1, 4],
       [3, 6]])

I know how to do this with loops. I'm looking for a Pythonic/NumPy way to do this.

Answer

For short arrays, using sets (convert each row to a tuple, then intersect the two sets of tuples) is probably the clearest and most readable way to do it.
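A minimal sketch of the set-based approach, assuming the rows hold hashable scalar values (true for ordinary numeric arrays). The helper name intersect_rows_sets is hypothetical. Sets do not keep any order, so the common rows are sorted here to make the result deterministic.

```python
import numpy as np

def intersect_rows_sets(A, B):
    """Return the rows common to A and B, using Python sets of row tuples."""
    # Each row becomes a tuple so it can be stored in a set
    rows_a = set(map(tuple, A.tolist()))
    rows_b = set(map(tuple, B.tolist()))
    common = sorted(rows_a & rows_b)  # sort for a deterministic row order
    # reshape keeps the column count even when there are no common rows
    return np.array(common, dtype=A.dtype).reshape(-1, A.shape[1])

A = np.array([[1, 4], [2, 5], [3, 6]])
B = np.array([[1, 4], [3, 6], [7, 8]])
print(intersect_rows_sets(A, B))
```

Every row is turned into a Python tuple, which is what makes this approach easy to read but comparatively slow for large arrays.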

Another way is to use numpy.intersect1d. You'll have to trick it into treating each row as a single value, though, by viewing the array with a structured dtype that has one field per column. This makes things a bit less readable:

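import numpy as np

A = np.array([[1, 4], [2, 5], [3, 6]])
B = np.array([[1, 4], [3, 6], [7, 8]])

nrows, ncols = A.shape
# Structured dtype with one field per column, so that each row is treated
# as a single element (assumes A and B are C-contiguous and share a dtype)
dtype = {'names': ['f{}'.format(i) for i in range(ncols)],
         'formats': ncols * [A.dtype]}

C = np.intersect1d(A.view(dtype), B.view(dtype))

# This last bit is optional if you're okay with "C" being a structured array;
# it converts the result back to a regular 2D array: [[1, 4], [3, 6]]
C = C.view(A.dtype).reshape(-1, ncols)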
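If you need this more than once, the same trick can be wrapped in a helper. This is a sketch, not part of the original answer: the name intersect_rows is hypothetical, and it additionally converts both inputs to a common dtype (via np.result_type) and to contiguous memory, because the structured view only lines up correctly when both arrays are contiguous and share a dtype.

```python
import numpy as np

def intersect_rows(A, B):
    """Return the rows that appear in both 2D arrays A and B (sorted, unique)."""
    A, B = np.asarray(A), np.asarray(B)
    if A.ndim != 2 or B.ndim != 2 or A.shape[1] != B.shape[1]:
        raise ValueError("A and B must be 2D with the same number of columns")
    # Use a common dtype and contiguous memory so the structured views line up
    common = np.result_type(A, B)
    A = np.ascontiguousarray(A, dtype=common)
    B = np.ascontiguousarray(B, dtype=common)
    ncols = A.shape[1]
    # One field per column: each row becomes a single structured element
    row_dtype = np.dtype({'names': ['f{}'.format(i) for i in range(ncols)],
                          'formats': ncols * [common]})
    C = np.intersect1d(A.view(row_dtype), B.view(row_dtype))
    # Reinterpret the structured result as a regular 2D array
    return C.view(common).reshape(-1, ncols)

A = np.array([[1, 4], [2, 5], [3, 6]])
B = np.array([[1, 4], [3, 6], [7, 8]])
print(intersect_rows(A, B))
```

Like intersect1d itself, the result is sorted and contains each common row once, so the original row order of A is not preserved.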

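For large arrays, this should be considerably faster than using sets, because intersect1d sorts and compares the rows in compiled NumPy code instead of building a Python tuple for every row.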
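A rough way to check the speed claim on your own machine: the sketch below times both approaches with timeit on random integer data. The array sizes, value range and seed are arbitrary assumptions, the helper names are hypothetical, and no timings are quoted because they depend on hardware and NumPy version.

```python
import timeit
import numpy as np

def rows_via_sets(A, B):
    # Pure-Python baseline: rows as tuples in sets, sorted for comparability
    common = sorted(set(map(tuple, A.tolist())) & set(map(tuple, B.tolist())))
    return np.array(common, dtype=A.dtype).reshape(-1, A.shape[1])

def rows_via_intersect1d(A, B):
    # Structured-dtype trick; assumes A and B share a dtype
    A = np.ascontiguousarray(A)
    B = np.ascontiguousarray(B, dtype=A.dtype)
    ncols = A.shape[1]
    dtype = {'names': ['f{}'.format(i) for i in range(ncols)],
             'formats': ncols * [A.dtype]}
    C = np.intersect1d(A.view(dtype), B.view(dtype))
    return C.view(A.dtype).reshape(-1, ncols)

rng = np.random.default_rng(0)  # arbitrary seed
# Small value range so that the two arrays actually share many rows
A = rng.integers(0, 50, size=(100_000, 2))
B = rng.integers(0, 50, size=(100_000, 2))

# Both methods return sorted unique rows, so the results should match exactly
assert np.array_equal(rows_via_sets(A, B), rows_via_intersect1d(A, B))

for name, fn in [("sets", rows_via_sets), ("intersect1d", rows_via_intersect1d)]:
    seconds = timeit.timeit(lambda: fn(A, B), number=5)
    print(f"{name}: {seconds / 5:.4f} s per call")
```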