YW P Kwon - 2 months ago
Python Question

In Tensorflow, how to use tf.gather() for the last dimension?

I am trying to gather slices of a tensor along the last dimension to build a partial connection between layers. The output tensor's shape is [batch_size, h, w, depth], and I want to select slices along the last dimension, for example:

# L is an intermediate tensor of shape [batch_size, h, w, depth]
partL = L[:, :, :, [0,2,3,8]]

However, tf.gather(L, [0, 2, 3, 8]) seems to work only on the first dimension (right?). Can anyone tell me how to do this?
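For reference, tf.gather(params, indices) indexes along axis 0 by default, so the call above selects batch entries rather than depth slices. In TensorFlow 2.x, tf.gather also takes an axis argument, which lets you gather along the last dimension directly. A minimal sketch, assuming TensorFlow 2.x with eager execution; the small tensor L below is a stand-in for the real intermediate tensor:

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

# Stand-in for the intermediate tensor L with shape [batch_size, h, w, depth].
L = tf.reshape(tf.range(10 * 2 * 2 * 9), [10, 2, 2, 9])
cols = [0, 2, 3, 8]

rows = tf.gather(L, cols)            # default axis=0: picks batch entries 0, 2, 3, 8
partL = tf.gather(L, cols, axis=-1)  # gathers along the last (depth) axis instead

print(rows.shape)   # (4, 2, 2, 9)
print(partL.shape)  # (10, 2, 2, 4)
```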


There is a tracking issue for supporting this use case: https://github.com/tensorflow/tensorflow/issues/206

For now you can:

  1. transpose your tensor so that the dimension to gather comes first, gather, then transpose back (transpose is expensive)

  2. reshape your tensor to 1-D (reshape is cheap), turn your column indices into a list of individual element indices in linear (flattened) order, gather, then reshape back

  3. use tf.gather_nd. You will still need to turn your column indices into a list of individual element indices.
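The three workarounds can be sketched roughly as follows. This assumes TensorFlow 2.x with eager execution and a tensor whose rank is known statically; the function names (gather_last_dim_transpose, gather_last_dim_reshape, gather_last_dim_gather_nd) are hypothetical, chosen for illustration. Each one should return the same values as the NumPy-style L[:, :, :, [0,2,3,8]] from the question.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)


def gather_last_dim_transpose(x, idx):
    """Option 1: move the last axis to the front, gather, move it back."""
    rank = len(x.shape)  # assumes the rank is known statically
    perm = [rank - 1] + list(range(rank - 1))  # e.g. [3, 0, 1, 2]
    inv_perm = list(range(1, rank)) + [0]      # e.g. [1, 2, 3, 0]
    gathered = tf.gather(tf.transpose(x, perm), idx)  # tf.gather uses axis 0 here
    return tf.transpose(gathered, inv_perm)


def gather_last_dim_reshape(x, idx):
    """Option 2: flatten to 1-D, gather linear element indices, reshape back."""
    idx = tf.convert_to_tensor(idx, dtype=tf.int32)
    shape = tf.shape(x)
    depth = shape[-1]
    n_rows = tf.size(x) // depth
    # In the row-major [n_rows, depth] view, element (r, c) sits at r * depth + c.
    linear = tf.range(n_rows)[:, None] * depth + idx[None, :]  # [n_rows, k]
    gathered = tf.gather(tf.reshape(x, [-1]), tf.reshape(linear, [-1]))
    # Restore the leading dimensions, with k columns in the last one.
    return tf.reshape(gathered, tf.concat([shape[:-1], tf.shape(idx)], axis=0))


def gather_last_dim_gather_nd(x, idx):
    """Option 3: build full coordinates for every output element, use tf.gather_nd."""
    idx = tf.convert_to_tensor(idx, dtype=tf.int32)
    shape = tf.shape(x)
    rank = len(x.shape)  # assumes the rank is known statically
    axes = [tf.range(shape[i]) for i in range(rank - 1)] + [idx]
    grids = tf.meshgrid(*axes, indexing="ij")  # each grid has shape [..., k]
    coords = tf.stack(grids, axis=-1)          # [..., k, rank]
    return tf.gather_nd(x, coords)


# Small example: depth 10, keep columns 0, 2, 3 and 8 of the last axis.
x = tf.reshape(tf.range(2 * 3 * 4 * 10, dtype=tf.float32), [2, 3, 4, 10])
cols = [0, 2, 3, 8]
part = gather_last_dim_reshape(x, cols)
print(part.shape)  # (2, 3, 4, 4)
```

The reshape and gather_nd versions trade the transpose for extra index arithmetic; which one is faster in practice will likely depend on the tensor sizes. In releases where tf.gather accepts an axis argument, tf.gather(x, cols, axis=-1) gives the same result without any of these workarounds.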