YW P Kwon - 3 months ago

Python Question

I am trying to gather slices of a tensor along its last dimension to build a partial connection between layers. The output tensor has shape [batch_size, h, w, depth], and I want to select slices along the last (depth) dimension, for example:

`# L is an intermediate tensor`

`partL = L[:, :, :, [0, 2, 3, 8]]`

However, `tf.gather(L, [0, 2, 3, 8])` seems to work only along the first dimension (right?). Can anyone tell me how to do this?

Answer

There is a tracking issue for supporting this use case: https://github.com/tensorflow/tensorflow/issues/206
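For reference, later TensorFlow releases added an `axis` argument to `tf.gather`, so on a recent version the slicing from the question can be written directly. The sketch below assumes TensorFlow 2.x with eager execution; the tensor `L` is a made-up stand-in for the intermediate tensor, with shape [2, 3, 4, 10].

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

# Hypothetical stand-in for the intermediate tensor L: [batch, h, w, depth] = [2, 3, 4, 10]
L = tf.reshape(tf.range(2 * 3 * 4 * 10, dtype=tf.float32), [2, 3, 4, 10])

# Gather along the last (depth) axis, i.e. axis 3 of a rank-4 tensor
partL = tf.gather(L, [0, 2, 3, 8], axis=3)
print(partL.shape.as_list())  # [2, 3, 4, 4]

# Sanity check against NumPy fancy indexing on the same data
print(np.array_equal(partL.numpy(), L.numpy()[:, :, :, [0, 2, 3, 8]]))  # True
```

On older versions without `axis`, the workarounds that follow still apply.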

For now you can:

- transpose your tensor so that the dimension to gather is first, gather, then transpose back (transpose is expensive)

- reshape your tensor into 1-D (reshape is cheap), turn your gather column indices into a list of individual element indices in linear (flattened) indexing, gather, then reshape back

- use `gather_nd`. You will still need to turn your column indices into a list of individual element indices.
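A minimal sketch of the transpose workaround, assuming TensorFlow 2.x with eager execution and a rank-4 input of shape [batch_size, h, w, depth]. The helper name `gather_last_dim_transpose` and the example tensor are hypothetical. The idea is to move the depth axis to the front, let `tf.gather` work on axis 0 as usual, and then move the axis back.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

def gather_last_dim_transpose(x, indices):
    """Gather slices of a rank-4 tensor along its last axis using two transposes."""
    # [batch, h, w, depth] -> [depth, batch, h, w]
    x_t = tf.transpose(x, perm=[3, 0, 1, 2])
    # tf.gather works on the first axis here: [len(indices), batch, h, w]
    gathered = tf.gather(x_t, indices)
    # [len(indices), batch, h, w] -> [batch, h, w, len(indices)]
    return tf.transpose(gathered, perm=[1, 2, 3, 0])

# Hypothetical stand-in for L: shape [2, 3, 4, 10]
L = tf.reshape(tf.range(2 * 3 * 4 * 10, dtype=tf.float32), [2, 3, 4, 10])
partL = gather_last_dim_transpose(L, [0, 2, 3, 8])
print(partL.shape.as_list())  # [2, 3, 4, 4]
print(np.array_equal(partL.numpy(), L.numpy()[:, :, :, [0, 2, 3, 8]]))  # True
```

Both transposes copy the whole tensor, which is where the cost mentioned above comes from.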
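The reshape workaround can be sketched as follows, again assuming TensorFlow 2.x and a hypothetical helper name (`gather_last_dim_flat`). After flattening in row-major order, the tensor can be viewed as `batch * h * w` rows of length `depth`, so the element in row `r` and column `c` sits at linear index `r * depth + c`. Broadcasting the row offsets against the column indices gives every linear index at once.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

def gather_last_dim_flat(x, indices):
    """Gather along the last axis by flattening x to 1-D and using linear indices."""
    shape = tf.shape(x)                          # dynamic shape, e.g. [batch, h, w, depth]
    depth = shape[-1]
    cols = tf.constant(indices, dtype=tf.int32)  # column (depth) indices to keep
    flat = tf.reshape(x, [-1])                   # 1-D, row-major order
    num_rows = tf.size(x) // depth               # batch * h * w rows of length depth
    # Linear index of (row r, column c) is r * depth + c; broadcast rows against columns
    row_starts = tf.range(num_rows) * depth      # shape [num_rows]
    linear = row_starts[:, None] + cols[None, :] # shape [num_rows, len(indices)]
    gathered = tf.gather(flat, linear)           # shape [num_rows, len(indices)]
    # Restore the leading dimensions with the new, smaller last dimension
    out_shape = tf.concat([shape[:-1], tf.shape(cols)], axis=0)
    return tf.reshape(gathered, out_shape)

# Hypothetical stand-in for L: shape [2, 3, 4, 10]
L = tf.reshape(tf.range(2 * 3 * 4 * 10, dtype=tf.float32), [2, 3, 4, 10])
partL = gather_last_dim_flat(L, [0, 2, 3, 8])
print(partL.shape.as_list())  # [2, 3, 4, 4]
print(np.array_equal(partL.numpy(), L.numpy()[:, :, :, [0, 2, 3, 8]]))  # True
```

Because the helper only relies on the last axis, the same code works for inputs of any rank.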
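For the `gather_nd` option, each output element needs its full coordinate `[b, i, j, c]`. One way to build those coordinates, sketched here under the assumption of TensorFlow 2.x and a rank-4 input (the helper name `gather_last_dim_nd` is hypothetical), is a `tf.meshgrid` over the batch, height and width ranges plus the chosen column indices.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

def gather_last_dim_nd(x, indices):
    """Gather along the last axis of a rank-4 tensor with tf.gather_nd."""
    shape = tf.shape(x)
    cols = tf.constant(indices, dtype=tf.int32)
    # One coordinate grid per axis, each of shape [batch, h, w, len(indices)]
    b, i, j, c = tf.meshgrid(tf.range(shape[0]), tf.range(shape[1]),
                             tf.range(shape[2]), cols, indexing="ij")
    # Full element coordinates: shape [batch, h, w, len(indices), 4]
    coords = tf.stack([b, i, j, c], axis=-1)
    # Each length-4 coordinate picks one element: result is [batch, h, w, len(indices)]
    return tf.gather_nd(x, coords)

# Hypothetical stand-in for L: shape [2, 3, 4, 10]
L = tf.reshape(tf.range(2 * 3 * 4 * 10, dtype=tf.float32), [2, 3, 4, 10])
partL = gather_last_dim_nd(L, [0, 2, 3, 8])
print(partL.shape.as_list())  # [2, 3, 4, 4]
print(np.array_equal(partL.numpy(), L.numpy()[:, :, :, [0, 2, 3, 8]]))  # True
```

The coordinate tensor is four times the size of the output, so this trades memory for avoiding the transposes.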