Clash - 9 months ago
Python Question

TensorFlow: how to check if a tensor row is only zeroes?

I'm training a simple network to predict the bounding box coordinates of a single object. However, some pictures contain no object. Since the network always makes a prediction, it also predicts a confidence value between 0 and 1 that should indicate the probability that there is an object in the picture. My tensor with the predictions is called logits and it's a (batch_size, 5) tensor (confidence, x, y, width and height). Similarly, the labels tensor is also (batch_size, 5).

Previously I was training only with images that always contained an object, so I could simply do:

loss = tf.nn.l2_loss(logits - labels)

I now also want to train with pictures that contain no object. When there is no object in the picture, I don't want the network to be penalized for whichever coordinates it predicted. In that case all that matters is the confidence value, which should be close to 0 (no object).

How should I structure my labels and loss function to accomplish this? I can set the label of images with no objects to all zeroes, but how do I check that a particular row is only zeroes? In that case, the corresponding row in the logits also needs to be set to zeroes (except the confidence value!) so that the loss incurred by the coordinates is also zero.


An approach to find if a row of a tensor is all zero:

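import tensorflow as tf

image = tf.fill([8, 8], 0)  # 8x8 int32 tensor of zeros
sess = tf.Session()
# Take row 1 as a (1, 8) slice; -1 means "all remaining columns".
image_row = tf.slice(image, [1, 0], [1, -1])
# The sum of absolute values is zero only if every entry is zero.
total = tf.reduce_sum(tf.abs(image_row))
is_all_zero = tf.equal(total, 0)
print(sess.run([total, is_all_zero, image_row]))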


Output:

[0, True, array([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)]
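
The check above looks at a single row. For the loss described in the question, the same idea can be applied to every row of the batch at once and turned into a mask. What follows is only a minimal sketch of that masking logic, not code from the original answer. It is written in NumPy so it runs without a session; the function name masked_l2_loss and the sample batch are hypothetical. In TensorFlow the corresponding calls would be tf.reduce_any, tf.not_equal, tf.cast, tf.concat and tf.nn.l2_loss.

```python
import numpy as np


def masked_l2_loss(logits, labels):
    """L2 loss that ignores box coordinates for rows whose label is all zeros.

    Both arguments are (batch_size, 5) arrays laid out as
    (confidence, x, y, width, height).
    """
    # 1.0 for rows with at least one non-zero label entry (object present),
    # 0.0 for all-zero rows (no object). Shape: (batch_size, 1).
    has_object = np.any(labels != 0, axis=1, keepdims=True).astype(logits.dtype)
    # The confidence column always counts; the four coordinate columns
    # count only when the row has an object.
    mask = np.concatenate(
        [np.ones_like(has_object), np.repeat(has_object, 4, axis=1)], axis=1
    )
    diff = (logits - labels) * mask
    # Same convention as tf.nn.l2_loss: sum of squares divided by 2.
    return 0.5 * np.sum(diff ** 2)


# Hypothetical batch: row 0 has an object, row 1 is a no-object image.
labels = np.array([[1.0, 0.5, 0.5, 0.2, 0.2],
                   [0.0, 0.0, 0.0, 0.0, 0.0]])
logits = np.array([[0.9, 0.5, 0.5, 0.2, 0.2],
                   [0.2, 0.7, 0.1, 0.3, 0.9]])

loss = masked_l2_loss(logits, labels)
print(loss)  # approximately 0.025
```

Only the two confidence errors (0.1 and 0.2) contribute here; the coordinates predicted for the no-object row are masked out. If the confidence label is always exactly 1 for images with an object, that column could probably serve as the mask directly, without the row check.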