I am trying to gather slices of a tensor along its last dimension, to implement partial connections between layers. The output tensor has shape [batch_size, h, w, depth], and I want to select slices by index along the last dimension, for example:
# L is an intermediate tensor of shape [batch_size, h, w, depth]
partL = L[:, :, :, [0,2,3,8]]
There's a tracking bug to support this use-case here: https://github.com/tensorflow/tensorflow/issues/206
For now you can:
Transpose your tensor so that the dimension to gather comes first, gather, then transpose back (transpose is expensive).
Reshape your tensor to 1-D (reshape is cheap), convert the column indices you want to gather into a list of individual element indices in linear (row-major) indexing, gather, then reshape back.
Use gather_nd. You will still need to turn your column indices into a list of individual element indices.
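The index conversion needed for the reshape approach can be sketched in plain Python. This is only an illustration under assumptions: the helper name last_dim_flat_indices is hypothetical, the shape is a tiny made-up example, and plain lists stand in for the tensor so the arithmetic can be checked without TensorFlow.

```python
def last_dim_flat_indices(shape, columns):
    """Linear (row-major) indices of the elements selected by [..., columns].

    Hypothetical helper for illustration; not part of TensorFlow.
    """
    depth = shape[-1]
    outer = 1
    for dim in shape[:-1]:
        outer *= dim  # number of rows once all leading dims are collapsed
    # In the [outer, depth] view, element (row, c) sits at row * depth + c.
    return [row * depth + c for row in range(outer) for c in columns]


shape = (2, 1, 2, 10)   # [batch_size, h, w, depth], kept tiny on purpose
columns = [0, 2, 3, 8]

# Stand-in for the tensor reshaped to 1-D: each value equals its own linear index.
flat = list(range(2 * 1 * 2 * 10))

idx = last_dim_flat_indices(shape, columns)
gathered = [flat[i] for i in idx]   # stand-in for gathering from the 1-D tensor

# With TensorFlow, the same idx list would presumably be used roughly as:
#   flat_L = tf.reshape(L, [-1])
#   partL = tf.reshape(tf.gather(flat_L, idx),
#                      [batch_size, h, w, len(columns)])
```

The gathered values come out ordered by row and then by column, so reshaping them to [batch_size, h, w, len(columns)] restores the original layout with only the selected depth slices.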