jrabary - 1 year ago 272

Python Question

I want to compute the pairwise squared distances for a batch of features in TensorFlow. I have a simple implementation using + and * operations that tiles the original tensor:

```
def pairwise_l2_norm2(x, y, scope=None):
    with tf.op_scope([x, y], scope, 'pairwise_l2_norm2'):
        size_x = tf.shape(x)[0]
        size_y = tf.shape(y)[0]
        xx = tf.expand_dims(x, -1)
        xx = tf.tile(xx, tf.pack([1, 1, size_y]))

        yy = tf.expand_dims(y, -1)
        yy = tf.tile(yy, tf.pack([1, 1, size_x]))
        yy = tf.transpose(yy, perm=[2, 1, 0])

        diff = tf.sub(xx, yy)
        square_diff = tf.square(diff)
        square_dist = tf.reduce_sum(square_diff, 1)

        return square_dist
```

This function takes as input two matrices of size (m, d) and (n, d) and computes the squared distance between each pair of row vectors. The output is a matrix of size (m, n) with elements `d_ij = dist(x_i, y_j)`.
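As a concrete reference for the expected output, here is a minimal pure-Python sketch (not the TensorFlow code; the name `pairwise_sq_dist_reference` is hypothetical) that computes each `d_ij` with a double loop and stores only the (m, n) result:

```python
def pairwise_sq_dist_reference(x, y):
    """Squared Euclidean distance between every row of x and every row of y.

    x: list of m rows, each of length d; y: list of n rows, each of length d.
    Returns an m-by-n nested list. Hypothetical helper, for illustration only.
    """
    return [
        [sum((xi_k - yj_k) ** 2 for xi_k, yj_k in zip(xi, yj)) for yj in y]
        for xi in x
    ]


x = [[1, 1], [2, 2], [3, 3]]  # m = 3, d = 2
y = [[1, 1], [3, 3]]          # n = 2, d = 2
print(pairwise_sq_dist_reference(x, y))  # [[0, 8], [2, 2], [8, 0]]
```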

The problem is that I have a large batch and high-dimensional features (large `m`, `n`, `d`), so replicating the tensor consumes a lot of memory.

I'm looking for another way to implement this without increasing the memory usage, storing only the final distance tensor, something like a double loop over the original tensors.


Answer Source

You can use some linear algebra to turn it into matrix ops. Note that what you need is a matrix `D`, where `a[i]` is the `i`th row of your original matrix and

```
D[i,j] = (a[i]-a[j])(a[i]-a[j])'
```

You can rewrite that into

```
D[i,j] = r[i] - 2 a[i]a[j]' + r[j]
```

where `r[i]` is the squared norm of the `i`th row of the original matrix.
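Spelled out, the rewrite is the usual expansion of a squared norm. Here $a_i$ stands for the row `a[i]` and $r_i$ for `r[i]`; the two cross terms are equal because each is a scalar:

```latex
\begin{aligned}
D_{ij} &= (a_i - a_j)(a_i - a_j)^\top \\
       &= a_i a_i^\top - a_i a_j^\top - a_j a_i^\top + a_j a_j^\top \\
       &= r_i - 2\, a_i a_j^\top + r_j, \qquad r_i = a_i a_i^\top .
\end{aligned}
```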

In a system that supports standard broadcasting rules, you can treat `r` as a column vector and write `D` as

```
D = r - 2 A A' + r'
```

In TensorFlow you could write this as

```
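import tensorflow as tf  # TensorFlow 1.x API (uses tf.Session)

A = tf.constant([[1, 1], [2, 2], [3, 3]])
# r[i] is the squared norm of row i
r = tf.reduce_sum(A*A, 1)
# turn r into column vector
r = tf.reshape(r, [-1, 1])
# broadcasting adds the column r and the row r' to every entry of -2 A A'
D = r - 2*tf.matmul(A, tf.transpose(A)) + tf.transpose(r)
sess = tf.Session()
sess.run(D)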
```

Result:

```
array([[0, 2, 8],
       [2, 0, 2],
       [8, 2, 0]], dtype=int32)
```
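The example above uses a single matrix `A`, while the question asks about two inputs of shapes (m, d) and (n, d). Below is a minimal sketch of the same expansion for two matrices, written in NumPy so it runs without a TensorFlow session. The function name `pairwise_sq_dist` and the clamp to zero are assumptions added here, not part of the original answer. The same expression should carry over to TensorFlow by swapping in `tf.reduce_sum`, `tf.matmul` and `tf.transpose`.

```python
import numpy as np


def pairwise_sq_dist(x, y):
    """Squared distances between rows of x (m, d) and rows of y (n, d).

    Uses D = rx - 2 x y' + ry', so the largest array created is the
    (m, n) result rather than an (m, d, n) tiled tensor.
    Hypothetical helper, for illustration only.
    """
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    rx = np.sum(x * x, axis=1).reshape(-1, 1)  # column vector, shape (m, 1)
    ry = np.sum(y * y, axis=1).reshape(1, -1)  # row vector, shape (1, n)
    d = rx - 2.0 * (x @ y.T) + ry              # broadcasts to shape (m, n)
    # Assumption: clamp tiny negative values caused by floating-point rounding.
    return np.maximum(d, 0.0)


A = np.array([[1, 1], [2, 2], [3, 3]])
print(pairwise_sq_dist(A, A))
```

For `A` this reproduces the matrix shown in the result above, as floats. On the memory side, taking m = n = 1000 and d = 512 in float32 as an illustrative case, the tiled (m, d, n) intermediate holds 512 million values (about 2 GB), while the (m, n) result holds one million (about 4 MB).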
