Ondrian - 28 days ago

Cython parallel loop problems

I am using Cython to compute a pairwise distance matrix with a custom metric, as a faster alternative to scipy.spatial.distance.pdist.

My Motivation



My metric has the form

def mymetric(u, v, w):
    return np.sum(w * (1 - np.abs(np.abs(u - v) / np.pi - 1))**2)


and the pairwise distance using scipy can be computed as

x = sp.spatial.distance.pdist(r, metric=lambda u, v: mymetric(u, v, w))


Here, r is an m-by-n matrix of m vectors of dimension n, and w is a "weight" factor of dimension n.

Since in my problem m is rather high, the computation is really slow. For m = 2000 and n = 10 this takes approximately 20 seconds.
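
For reference, here is a minimal, self-contained sketch of the setup described above (the uniform random angles, the weights and the seed are made up purely for illustration):

import numpy as np
from scipy.spatial.distance import pdist

def mymetric(u, v, w):
    # the metric defined above
    return np.sum(w * (1 - np.abs(np.abs(u - v) / np.pi - 1))**2)

np.random.seed(0)  # arbitrary seed, just for reproducibility
m, n = 2000, 10
r = np.random.uniform(0.0, 2.0 * np.pi, size=(m, n))  # m vectors of dimension n
w = np.random.uniform(0.0, 1.0, size=n)                # weight factor of dimension n

x = pdist(r, metric=lambda u, v: mymetric(u, v, w))
print(x.shape)  # condensed distance matrix of length m * (m - 1) // 2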

Initial solution with Cython



I implemented a simple function in Cython that computes the pairwise distance and immediately got very promising results -- a speedup of over 500x.

import numpy as np
cimport numpy as np
import cython

from libc.math cimport fabs, M_PI

@cython.wraparound(False)
@cython.boundscheck(False)
def pairwise_distance(np.ndarray[np.double_t, ndim=2] r, np.ndarray[np.double_t, ndim=1] w):
    cdef int i, j, k, c, size
    cdef np.ndarray[np.double_t, ndim=1] ans
    size = r.shape[0] * (r.shape[0] - 1) / 2
    ans = np.zeros(size, dtype=r.dtype)
    c = -1
    for i in range(r.shape[0]):
        for j in range(i + 1, r.shape[0]):
            c += 1
            for k in range(r.shape[1]):
                ans[c] += w[k] * (1.0 - fabs(fabs(r[i, k] - r[j, k]) / M_PI - 1.0))**2.0

    return ans
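
As a sanity check, I compare the Cython result against the SciPy baseline with a smaller m (a sketch; it assumes the extension has been compiled into an importable module named pdist, as in the compiler error shown further down):

import numpy as np
from scipy.spatial.distance import pdist as sp_pdist
from pdist import pairwise_distance  # the compiled Cython module (name assumed)

np.random.seed(0)
r = np.random.uniform(0.0, 2.0 * np.pi, size=(200, 10))
w = np.random.uniform(0.0, 1.0, size=10)

ref = sp_pdist(r, metric=lambda u, v: np.sum(w * (1 - np.abs(np.abs(u - v) / np.pi - 1))**2))
out = pairwise_distance(r, w)
print(np.allclose(ref, out))  # should print True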


Problems using OpenMP



I wanted to speed up the computation some more using OpenMP; however, the following solution is roughly 3 times slower than the serial version.

import numpy as np
cimport numpy as np

import cython
from cython.parallel import prange, parallel

cimport openmp

from libc.math cimport fabs, M_PI

@cython.wraparound(False)
@cython.boundscheck(False)
def pairwise_distance_omp(np.ndarray[np.double_t, ndim=2] r, np.ndarray[np.double_t, ndim=1] w):
    cdef int i, j, k, c, size, m, n
    cdef np.double_t a
    cdef np.ndarray[np.double_t, ndim=1] ans
    m = r.shape[0]
    n = r.shape[1]
    size = m * (m - 1) / 2
    ans = np.zeros(size, dtype=r.dtype)
    with nogil, parallel(num_threads=8):
        for i in prange(m, schedule='dynamic'):
            for j in range(i + 1, m):
                c = i * (m - 1) - i * (i + 1) / 2 + j - 1
                for k in range(n):
                    ans[c] += w[k] * (1.0 - fabs(fabs(r[i, k] - r[j, k]) / M_PI - 1.0))**2.0

    return ans
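
The index expression c = i * (m - 1) - i * (i + 1) / 2 + j - 1 is the usual condensed-distance-matrix index for a pair (i, j) with i < j; a quick pure-Python check (just for illustration) that it matches the sequential counter used in the serial version:

m = 7  # any small m will do for the check
c_sequential = -1
for i in range(m):
    for j in range(i + 1, m):
        c_sequential += 1
        c_formula = i * (m - 1) - i * (i + 1) // 2 + j - 1
        assert c_formula == c_sequential
print("index formula matches the sequential counter for m =", m)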


I don't know why it is actually slower, but I tried to introduce the following changes. This resulted not only in slightly worse performance, but the resulting distances in ans are also computed correctly only at the beginning of the array; the rest is just zeros. The speedup achieved through this is negligible.

import numpy as np
cimport numpy as np

import cython
from cython.parallel import prange, parallel

cimport openmp

from libc.math cimport fabs, M_PI
from libc.stdlib cimport malloc, free

@cython.wraparound(False)
@cython.boundscheck(False)
def pairwise_distance_omp_2(np.ndarray[np.double_t, ndim=2] r, np.ndarray[np.double_t, ndim=1] w):
    cdef int k, l, c, m, n
    cdef Py_ssize_t i, j, d
    cdef size_t size
    cdef int *ci, *cj

    cdef np.ndarray[np.double_t, ndim=1, mode="c"] ans

    cdef np.ndarray[np.double_t, ndim=2, mode="c"] data
    cdef np.ndarray[np.double_t, ndim=1, mode="c"] weight

    data = np.ascontiguousarray(r, dtype=np.float64)
    weight = np.ascontiguousarray(w, dtype=np.float64)

    m = r.shape[0]
    n = r.shape[1]
    size = m * (m - 1) / 2
    ans = np.zeros(size, dtype=r.dtype)

    cj = <int*> malloc(size * sizeof(int))
    ci = <int*> malloc(size * sizeof(int))

    c = -1
    for i in range(m):
        for j in range(i + 1, m):
            c += 1
            ci[c] = i
            cj[c] = j

    with nogil, parallel(num_threads=8):
        for d in prange(size, schedule='guided'):
            for k in range(n):
                ans[d] += weight[k] * (1.0 - fabs(fabs(data[ci[d], k] - data[cj[d], k]) / M_PI - 1.0))**2.0

    # release the temporary index buffers
    free(ci)
    free(cj)

    return ans


For all functions, I am using the following .pyxbld file:

def make_ext(modname, pyxfilename):
    from distutils.extension import Extension
    return Extension(name=modname,
                     sources=[pyxfilename],
                     extra_compile_args=['-O3', '-march=native', '-ffast-math', '-fopenmp'],
                     extra_link_args=['-fopenmp'],
                     )
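
For context, the .pyxbld file is picked up by pyximport; with the sources named pdist.pyx and pdist.pyxbld (the module name pdist is taken from the compiler error further down), the build then happens on first import, roughly like this:

import numpy as np
import pyximport

# the NumPy include dir is needed for "cimport numpy"; it could equally be
# added to the Extension in make_ext instead
pyximport.install(setup_args={"include_dirs": np.get_include()})

import pdist  # triggers compilation of pdist.pyx using pdist.pyxbld's make_ext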


Summary



I have zero experience with Cython and know only the basics of C. I would appreciate any suggestion about what may be the cause of this unexpected behavior, or even how to rephrase my question better.




Best serial solution (10% faster than the original serial version)



@cython.cdivision(True)
@cython.wraparound(False)
@cython.boundscheck(False)
def pairwise_distance_2(np.ndarray[np.double_t, ndim=2] r, np.ndarray[np.double_t, ndim=1] w):
    cdef int i, j, k, c, size
    cdef np.ndarray[np.double_t, ndim=1] ans
    cdef np.double_t accumulator, tmp
    size = r.shape[0] * (r.shape[0] - 1) / 2
    ans = np.zeros(size, dtype=r.dtype)
    c = -1
    for i in range(r.shape[0]):
        for j in range(i + 1, r.shape[0]):
            c += 1
            accumulator = 0
            for k in range(r.shape[1]):
                tmp = (1.0 - fabs(fabs(r[i, k] - r[j, k]) / M_PI - 1.0))
                accumulator += w[k] * (tmp*tmp)
            ans[c] = accumulator

    return ans


Best parallel solution (1% faster than the original parallel version, 6 times faster than the best serial version using 8 threads)



@cython.cdivision(True)
@cython.wraparound(False)
@cython.boundscheck(False)
def pairwise_distance_omp_2d(np.ndarray[np.double_t, ndim=2] r, np.ndarray[np.double_t, ndim=1] w):
    cdef int i, j, k, c, size, m, n
    cdef np.ndarray[np.double_t, ndim=1] ans
    cdef np.double_t accumulator, tmp
    m = r.shape[0]
    n = r.shape[1]
    size = m * (m - 1) / 2
    ans = np.zeros(size, dtype=r.dtype)
    with nogil, parallel(num_threads=8):
        for i in prange(m, schedule='dynamic'):
            for j in range(i + 1, m):
                c = i * (m - 1) - i * (i + 1) / 2 + j - 1
                accumulator = 0
                for k in range(n):
                    tmp = (1.0 - fabs(fabs(r[i, k] - r[j, k]) / M_PI - 1.0))
                    ans[c] += w[k] * (tmp*tmp)

    return ans
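
A rough timing harness I use to compare the two (a sketch; it assumes both functions are importable from the same compiled pdist module, and the numbers obviously depend on the machine and the number of threads):

import time
import numpy as np
from pdist import pairwise_distance_2, pairwise_distance_omp_2d  # module name assumed

np.random.seed(0)
r = np.random.uniform(0.0, 2.0 * np.pi, size=(2000, 10))
w = np.random.uniform(0.0, 1.0, size=10)

for fn in (pairwise_distance_2, pairwise_distance_omp_2d):
    t0 = time.perf_counter()
    out = fn(r, w)
    print(fn.__name__, time.perf_counter() - t0, "s")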





Unsolved issues:



When I try to apply the accumulator solution proposed in the answer, I get the following error:

Error compiling Cython file:
------------------------------------------------------------
...
                c = i * (m - 1) - i * (i + 1) / 2 + j - 1
                accumulator = 0
                for k in range(n):
                    tmp = (1.0 - fabs(fabs(r[i, k] - r[j, k]) / M_PI - 1.0))
                    accumulator += w[k] * (tmp*tmp)
                ans[c] = accumulator
                                    ^
------------------------------------------------------------
pdist.pyx:207:36: Cannot read reduction variable in loop body


Full code:

@cython.cdivision(True)
@cython.wraparound(False)
@cython.boundscheck(False)
def pairwise_distance_omp(np.ndarray[np.double_t, ndim=2] r, np.ndarray[np.double_t, ndim=1] w):
    cdef int i, j, k, c, size, m, n
    cdef np.ndarray[np.double_t, ndim=1] ans
    cdef np.double_t accumulator, tmp
    m = r.shape[0]
    n = r.shape[1]
    size = m * (m - 1) / 2
    ans = np.zeros(size, dtype=r.dtype)
    with nogil, parallel(num_threads=8):
        for i in prange(m, schedule='dynamic'):
            for j in range(i + 1, m):
                c = i * (m - 1) - i * (i + 1) / 2 + j - 1
                accumulator = 0
                for k in range(n):
                    tmp = (1.0 - fabs(fabs(r[i, k] - r[j, k]) / M_PI - 1.0))
                    accumulator += w[k] * (tmp*tmp)
                ans[c] = accumulator

    return ans
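
A possible workaround I have since come across (not from the answer below, so treat it as an untested sketch): Cython infers that any variable updated with an in-place operator inside a prange body is a reduction variable, and reduction variables may not be read inside the loop. Writing the update as a plain assignment should keep accumulator an ordinary thread-private variable, which can then be stored into ans[c]. The function name pairwise_distance_omp_acc is mine, just to distinguish it:

import numpy as np
cimport numpy as np

import cython
from cython.parallel import prange, parallel

from libc.math cimport fabs, M_PI

@cython.cdivision(True)
@cython.wraparound(False)
@cython.boundscheck(False)
def pairwise_distance_omp_acc(np.ndarray[np.double_t, ndim=2] r, np.ndarray[np.double_t, ndim=1] w):
    cdef int i, j, k, c, size, m, n
    cdef np.ndarray[np.double_t, ndim=1] ans
    cdef np.double_t accumulator, tmp
    m = r.shape[0]
    n = r.shape[1]
    size = m * (m - 1) / 2
    ans = np.zeros(size, dtype=r.dtype)
    with nogil, parallel(num_threads=8):
        for i in prange(m, schedule='dynamic'):
            for j in range(i + 1, m):
                c = i * (m - 1) - i * (i + 1) / 2 + j - 1
                accumulator = 0
                for k in range(n):
                    tmp = (1.0 - fabs(fabs(r[i, k] - r[j, k]) / M_PI - 1.0))
                    # plain assignment instead of "+=" so Cython does not treat
                    # accumulator as a reduction variable
                    accumulator = accumulator + w[k] * (tmp*tmp)
                ans[c] = accumulator

    return ans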

Answer

I haven't timed this myself, so it's possible this might not help too much; however:

If you run cython -a to get an annotated version of your initial attempt (pairwise_distance_omp), you'll find the ans[c] += ... line is yellow, suggesting it has Python overhead. A look at the C code corresponding to that line suggests that it's checking for division by zero. One key part of it starts:

if (unlikely(M_PI == 0)) {

You know this will never be true (and in any case you'd probably live with NaN values rather than an exception if it was). You can avoid this check by adding the following extra decorator to the function:

@cython.cdivision(True)
# other decorators
def pairwise_distance_omp # etc...

This cuts out quite a bit of C code, including bits that have to be run in a single thread. The flip-side is that most of that code should never be run, and the compiler should probably be able to work that out, so it isn't clear how much difference that will make.


Second suggestion:

# at the top
cdef np.double_t accumulator, tmp

    # further down later in the loop:
    c = i * (m - 1) - i * (i + 1) / 2 + j - 1
    accumulator = 0
    for k in range(r.shape[1]):
        tmp = (1.0 - fabs(fabs(r[i, k] - r[j, k]) / M_PI - 1.0))
        accumulator += w[k] * (tmp*tmp)
    ans[c] = accumulator

This has two advantages hopefully: 1) tmp*tmp should probably be quicker than a floating-point power to the exponent 2.0. 2) You avoid reading from the ans array, which might be a bit slow because the compiler always has to be careful that some other thread hasn't changed it (even though you know it shouldn't have).
