user2133814 - 29 days ago

GEMM using NumPy einsum

Can a single NumPy einsum statement replicate GEMM functionality? Scalar and matrix multiplication seem straightforward, but I haven't found how to get the "+" working. In case it's simpler, D = alpha * A @ B + beta * C (writing the result to a new array D instead of updating C in place, as GEMM does) would be acceptable (preferable, actually).

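import numpy as np

alpha = 2
beta = 3
A = np.arange(9).reshape(3, 3)
B = A + 1
C = B + 1

# scaled matrix product with np.dot
left_part = alpha*np.dot(A, B)
print(left_part)
# the same with einsum; the empty subscript before the first comma labels the scalar alpha
left_part = np.einsum(',ij,jk->ik', alpha, A, B)
print(left_part)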

Answer

There seems to be some confusion here. np.einsum handles operations that can be cast in the form broadcast-multiply-reduce; element-wise addition of separate operands is not part of its scope.
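To make the broadcast-multiply-reduce reading concrete, here is a small sketch (the arrays and names are chosen only for illustration) checking that a matrix-product einsum gives the same result as the three steps written out by hand. All operands of an einsum are multiplied together, so there is no subscript syntax for adding a separate term such as beta * C.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((3, 4))
y = rng.random((4, 5))

# einsum: broadcast x to (i, j, 1) and y to (1, j, k), multiply, sum over j
via_einsum = np.einsum('ij,jk->ik', x, y)

# the same three steps written out explicitly
broadcast_product = x[:, :, None] * y[None, :, :]  # shape (3, 4, 5)
via_explicit = broadcast_product.sum(axis=1)        # reduce over j

print(np.allclose(via_einsum, via_explicit))  # True
print(np.allclose(via_einsum, x @ y))         # True
```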

The reason you need a dedicated routine like this for the multiplication is that writing these operations out "naively" can quickly exhaust memory or computing resources. Consider, for example, matrix multiplication:

import numpy as np
x, y = np.ones((2, 2000, 2000))

# explicit loop (sum of outer products) - ridiculously slow
a = sum(x[:,j,np.newaxis] * y[j,:] for j in range(2000))

# explicit broadcast-multiply-reduce over a (2000, 2000, 2000) intermediate:
# typically raises MemoryError
a = (x[:,:,None] * y[None,:,:]).sum(1)

# einsum or dot: fast and memory-saving
a = np.einsum('ij,jk->ik', x, y)
a = np.dot(x, y)
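A quick back-of-the-envelope check of why the broadcast version fails, assuming float64 data as produced by np.ones; it only does the arithmetic and allocates nothing large.

```python
import numpy as np

n = 2000
itemsize = np.dtype(np.float64).itemsize  # 8 bytes per element

# x[:, :, None] * y[None, :, :] materializes an n x n x n array
intermediate_bytes = n**3 * itemsize
# the final product is only n x n
result_bytes = n**2 * itemsize

print(f"intermediate: {intermediate_bytes / 1e9:.0f} GB")  # 64 GB
print(f"result:       {result_bytes / 1e6:.0f} MB")        # 32 MB
```

The intermediate is n times larger than the result, which is exactly the cost that einsum and dot avoid by reducing as they go.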

Addition, however, combines trivially with the Einstein convention: each term can be computed by its own einsum call and the results added, so you can write your BLAS-like problem simply as:

D = np.einsum(',ij,jk->ik', alpha, A, B) + np.einsum(',ik->ik', beta, C)

with minimal memory overhead (you can rewrite most of it as in-place operations if you are really concerned about memory) and constant runtime overhead (the cost of a few extra Python-to-C calls: two einsum calls plus the addition).
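If memory is the concern, a minimal sketch of the in-place rewrite could look like this. gemm_like is a hypothetical helper, not a NumPy function; it relies on einsum's out= parameter and assumes float arrays, because in-place scaling of an integer array by a float would fail with a casting error.

```python
import numpy as np

def gemm_like(alpha, A, B, beta, C, out):
    """Hypothetical helper: writes alpha * A @ B + beta * C into `out`.

    Assumes float arrays and a preallocated `out` of shape
    (A.shape[0], B.shape[1]) with a matching dtype.
    """
    np.einsum('ij,jk->ik', A, B, out=out)  # matrix product written straight into out
    out *= alpha                           # scale in place
    out += beta * C                        # beta * C is the one remaining temporary
    return out

# the question's data, as floats so the in-place updates need no casting
alpha, beta = 2.0, 3.0
A = np.arange(9, dtype=np.float64).reshape(3, 3)
B = A + 1
C = B + 1

# reference: the two-einsum expression
D_ref = np.einsum(',ij,jk->ik', alpha, A, B) + np.einsum(',ik->ik', beta, C)

out = np.empty((3, 3))
gemm_like(alpha, A, B, beta, C, out)
print(np.allclose(out, D_ref))  # True
```

If C itself may be overwritten (true GEMM semantics), the beta * C temporary can also be avoided by scaling C in place and adding the product to it instead.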

So regarding performance, this seems, respectfully, like a case of premature optimization to me: have you actually verified that splitting a GEMM-like operation into separate NumPy calls is a bottleneck in your code? If it is, I suggest the following (in order of increasing effort):

  1. Try scipy.linalg.blas.dgemm, carefully, since it is a thin wrapper around the double-precision BLAS routine. I would be surprised if you got significantly better performance, since dgemm is usually just a building block itself.

  2. Try an expression compiler such as Theano (essentially what you are proposing).

  3. Write your own generalised ufunc using Cython or C.
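For option 1, a minimal sketch, assuming SciPy is installed and float64 inputs (the "d" in dgemm stands for double precision). By default C is not overwritten and the result comes back as a new array; non-Fortran-ordered inputs may be copied internally, which is one reason to benchmark before expecting a speedup.

```python
import numpy as np
from scipy.linalg.blas import dgemm

alpha, beta = 2.0, 3.0
A = np.arange(9, dtype=np.float64).reshape(3, 3)
B = A + 1
C = B + 1

# one BLAS call computing alpha * A @ B + beta * C
D = dgemm(alpha, A, B, beta=beta, c=C)

print(np.allclose(D, alpha * (A @ B) + beta * C))  # True
```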
