Fedal - 7 months ago

In TensorFlow, can I use tf.gather() for a partial connection?

I am trying to implement a partial connection between layers. Let's say I want to use only some of the feature maps, e.g., the first and third ones.


  • Is it correct to use tf.gather() for this purpose?

  • Can I just use the indexing operator [] instead of tf.gather(), as below?

  • Will gathering by indices work for backpropagation? I find it hard to imagine how TensorFlow knows, during its internal backprop pass, that the connection came from the first and third feature maps (information that is hard-coded). Does tf.gather() remember the connection?



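Code:

import tensorflow as tf

# Let's say L1 is the layer-1 output, of shape [batch_size, image_size, image_size, depth1]
partL1 = L1[:, :, :, [0, 2]]  # keep only feature maps 0 and 2 (the indexing this question asks about)
# W2 is a tf.Variable of shape [5, 5, 2, depth2]
# strides and padding are required arguments of conv2d; these values are illustrative
conv2 = tf.nn.conv2d(partL1, W2, strides=[1, 1, 1, 1], padding='SAME')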

dga
Answer

Yes, no, yes. :-) (a) Yes, you can use gather to pick a subset of a layer to propagate to the next layer, as you suggested.
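A minimal sketch of the partial connection with tf.gather(), assuming TensorFlow 2.x (eager execution) and NHWC tensors as in the question. The sizes, the random stand-ins for L1 and W2, and the stride/padding values are hypothetical. The detail that matters is axis=3: tf.gather() gathers along axis 0 by default, which here would select batch elements rather than feature maps.

```python
import tensorflow as tf

# Hypothetical sizes, chosen only for illustration.
batch_size, image_size, depth1, depth2 = 4, 28, 3, 8

# Random stand-in for the layer-1 output, NHWC: [batch, height, width, channels].
L1 = tf.random.normal([batch_size, image_size, image_size, depth1])

# Keep only feature maps 0 and 2. axis=3 gathers along the channel axis;
# without it, tf.gather would pick entries along axis 0 (the batch axis).
partL1 = tf.gather(L1, indices=[0, 2], axis=3)

# W2 has in_channels = 2 because only two feature maps are passed on.
W2 = tf.Variable(tf.random.normal([5, 5, 2, depth2], stddev=0.1))

# Illustrative choice: stride 1 with SAME padding keeps the spatial size.
conv2 = tf.nn.conv2d(partL1, W2, strides=[1, 1, 1, 1], padding="SAME")

print(partL1.shape.as_list())  # [4, 28, 28, 2]
print(conv2.shape.as_list())   # [4, 28, 28, 8]
```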

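(b) No, unfortunately you can't use the indexing operator with a list of indices: [] on a Tensor accepts integers, slices, ellipsis (...) and tf.newaxis, but not lists. You need to call tf.gather() explicitly.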
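A small check of point (b), again assuming TensorFlow 2.x eager execution and a hypothetical toy tensor standing in for L1. In the TensorFlow 2.x releases I'm aware of, a Python list inside [] is rejected with a TypeError, while basic slices work; tf.gather() returns the same two feature maps you would get by concatenating two single-channel slices (the concatenation is only used here as a cross-check).

```python
import tensorflow as tf

# Hypothetical stand-in for L1: [batch, height, width, channels] with 4 channels.
L1 = tf.reshape(tf.range(2 * 3 * 3 * 4, dtype=tf.float32), [2, 3, 3, 4])

# A list of indices is not a valid index for a tf.Tensor.
try:
    L1[:, :, :, [0, 2]]
    list_index_ok = True
except TypeError as err:
    list_index_ok = False
    print("TypeError:", err)

# Basic slicing is supported; each slice keeps one feature map
# and leaves the channel axis in place (size 1).
first = L1[:, :, :, 0:1]
third = L1[:, :, :, 2:3]

# tf.gather picks the same two feature maps in a single op.
picked = tf.gather(L1, [0, 2], axis=3)

# Cross-check: gather result matches the two slices joined on the channel axis.
same = (picked.numpy() == tf.concat([first, third], axis=3).numpy()).all()
print("list indexing accepted:", list_index_ok)
print("gather matches slices:", same)
```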

(c) Yes. The indices are an input to the Gather op, so TensorFlow keeps them and uses them during backprop. If you're curious how, see the implementation of Gather's gradient: it looks at the op's inputs and routes the incoming gradient back to the gathered positions, so the feature maps that were not selected receive zero gradient.
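To see point (c) in action, here is a sketch assuming TensorFlow 2.x eager execution and a hypothetical random L1 with four feature maps. tf.GradientTape differentiates a loss computed only from the gathered maps; because the loss is a plain sum, every selected position should get a gradient of 1 and the unselected feature maps should get 0.

```python
import tensorflow as tf

# Hypothetical stand-in for L1: [batch=2, 5, 5, channels=4].
L1 = tf.random.normal([2, 5, 5, 4])

with tf.GradientTape() as tape:
    tape.watch(L1)  # L1 is a plain tensor, so ask the tape to track it
    partL1 = tf.gather(L1, [0, 2], axis=3)  # keep feature maps 0 and 2
    loss = tf.reduce_sum(partL1)  # d(loss)/d(partL1) is all ones

grad = tape.gradient(loss, L1)
# Gather's gradient can come back as tf.IndexedSlices; densify it to be safe.
grad = tf.convert_to_tensor(grad)

# Sum the gradient per channel: only channels 0 and 2 received any.
per_channel = tf.reduce_sum(grad, axis=[0, 1, 2])
print(per_channel.numpy().tolist())  # [50.0, 0.0, 50.0, 0.0]
```

Each selected channel has 2 × 5 × 5 = 50 positions, each with gradient 1, which is where the 50s come from. For a gather along axis 0, TensorFlow may hand back the gradient as tf.IndexedSlices rather than a dense tensor; tf.convert_to_tensor() densifies either form.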
