murushiv - 1 month ago
Python Question

Tensorflow - transfer learning implementation (semantic segmentation)

I'm implementing a CNN architecture (the FCN-8s model, built on a pretrained VGG16 model) for semantic segmentation on my own data (2 classes, so a binary per-pixel classification).

How I intend to go about this is:


  1. Load the pre-trained model with weights

  2. Add/remove additional higher layers to convert to FCN

  3. Freeze lower layers of the pre-trained model (to not update during the training phase)

  4. Train the network on specific dataset



Assuming this is correct, how do I go about freezing the lower layers of my TensorFlow model? (I'm looking for specific implementation details.) I had a look at the Inception retraining tutorial for TensorFlow, but I'm still not sure how to do it.

This is the workflow I have in mind:


  1. Run my data through the existing pretrained model, and extract the feature outputs, without training it. (how?)

  2. Feed these feature outputs into another network containing the higher layers - and go about training it.



Any suggestions would be helpful!

Else, if I'm wrong, how should I be thinking of this?

UPDATE:

I took up chasep255's suggestion below, and tried to use tf.stop_gradient so as to "freeze" the lower layers in my model. Clearly, there is something wrong with my implementation. Possible alternatives/suggestions?

The model is built based on the FCN (for semantic segmentation) paper. I extract logits from the model architecture, i.e., my features, which I initially feed directly into a loss function to minimize it with a softmax classifier (per-pixel classification). deconv_1 is my logits tensor, of shape [batch, h, w, num_classes] = [1, 750, 750, 2].

Implementation:

logits = vgg_fcn.deconv_1

stopper = tf.stop_gradient(logits, 'stop_gradients')

loss = train_func.loss(stopper, labels_placeholder, 2)

with tf.name_scope('Optimizer'):
    train_op = train_func.training(loss, FLAGS.learning_rate)

with tf.name_scope('Accuracy'):
    eval_correct = train_func.accuracy_eval(logits, labels_placeholder)
    accuracy_summary = tf.scalar_summary('Accuracy', eval_correct)


I then run these Graph operations as below:

_, acc, loss_value = sess.run([train_op,eval_correct, loss], feed_dict=feed_dict)


When I run the training cycle this way, the loss value is not optimized, almost certainly because of how I've introduced the tf.stop_gradient op.

For more details, my loss function below:

def loss(logits, labels, num_classes):

    logits = tf.reshape(logits, [-1, num_classes])
    #epsilon = tf.constant(value=1e-4)
    #logits = logits + epsilon

    labels = tf.to_int64(tf.reshape(labels, [-1]))
    print ('shape of logits: %s' % str(logits.get_shape()))
    print ('shape of labels: %s' % str(labels.get_shape()))

    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels, name='Cross_Entropy')
    cross_entropy_mean = tf.reduce_mean(cross_entropy, name='xentropy_mean')
    tf.add_to_collection('losses', cross_entropy_mean)

    loss = tf.add_n(tf.get_collection('losses'), name='total_loss')
    return loss

Answer

You could run the pretrained model on its own with sess.run(pretrained_output, ...) and capture its output. After you save that output, you can feed it into your own model as input. In this case the optimizer cannot propagate gradients into the pretrained model, because the two graphs are not connected.
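A minimal sketch of this first option, assuming a TensorFlow install that provides the `tensorflow.compat.v1` module (the question itself uses an older API). The tiny one-layer "pretrained" network, the 8x8 inputs, and all names (`images`, `features_in`, `w_base`, `w_head`) are hypothetical stand-ins for illustration, not the asker's actual FCN code.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
    # Hypothetical stand-in for the pretrained network (e.g. the VGG16 layers).
    images = tf.placeholder(tf.float32, [None, 8, 8, 3], name="images")
    with tf.variable_scope("pretrained"):
        w_base = tf.get_variable(
            "w", [3, 3, 3, 4],
            initializer=tf.random_normal_initializer(stddev=0.1, seed=0))
        pretrained_output = tf.nn.relu(
            tf.nn.conv2d(images, w_base, strides=[1, 1, 1, 1], padding="SAME"))

    # The new layers read saved features from their own placeholder, so there
    # is no graph path from the loss back to the pretrained weights.
    features_in = tf.placeholder(tf.float32, [None, 8, 8, 4], name="features")
    labels = tf.placeholder(tf.int64, [None, 8, 8], name="labels")
    with tf.variable_scope("head"):
        w_head = tf.get_variable(
            "w", [1, 1, 4, 2],
            initializer=tf.random_normal_initializer(stddev=0.1, seed=1))
        logits = tf.nn.conv2d(
            features_in, w_head, strides=[1, 1, 1, 1], padding="SAME")

    # Per-pixel softmax cross-entropy, flattened as in the question.
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.reshape(labels, [-1]), logits=tf.reshape(logits, [-1, 2])))
    train_op = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
    init = tf.global_variables_initializer()

# Random toy data in place of a real dataset.
rng = np.random.RandomState(0)
image_batch = rng.rand(2, 8, 8, 3).astype(np.float32)
label_batch = rng.randint(0, 2, size=(2, 8, 8)).astype(np.int64)

with tf.Session(graph=graph) as sess:
    sess.run(init)
    base_before = sess.run(w_base)

    # Step 1: forward pass only; capture the pretrained output as a NumPy array.
    cached_features = sess.run(pretrained_output,
                               feed_dict={images: image_batch})

    # Step 2: train the new layers on the cached features.
    losses = []
    for _ in range(20):
        _, loss_value = sess.run(
            [train_op, loss],
            feed_dict={features_in: cached_features, labels: label_batch})
        losses.append(loss_value)

    base_after = sess.run(w_base)
```

With a real dataset you would typically write `cached_features` to disk once and reuse it for every epoch, which also saves the cost of re-running the pretrained layers.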

You could also attach the pretrained model to your model normally and then pass the pretrained output through tf.stop_gradient(), which would prevent the optimizer from propagating the gradients back into the pretrained model.
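Where the tf.stop_gradient call sits matters: it should wrap the output of the pretrained part, before the layers you want to train. If it wraps the final logits (as in the update to the question), no gradient reaches any layer at all. The sketch below checks both placements with tf.gradients. It assumes the `tensorflow.compat.v1` module is available; the toy layers and the `pixel_loss` helper are hypothetical, not the asker's code.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
    images = tf.placeholder(tf.float32, [None, 8, 8, 3], name="images")
    labels = tf.placeholder(tf.int64, [None, 8, 8], name="labels")

    # Hypothetical stand-in for the pretrained lower layers.
    with tf.variable_scope("pretrained"):
        w_base = tf.get_variable(
            "w", [3, 3, 3, 4],
            initializer=tf.random_normal_initializer(stddev=0.1, seed=0))
        base_features = tf.nn.relu(
            tf.nn.conv2d(images, w_base, strides=[1, 1, 1, 1], padding="SAME"))

    # Cut the backward pass between the pretrained layers and the new layers.
    frozen_features = tf.stop_gradient(base_features, name="stop_gradients")

    # New, trainable higher layers.
    with tf.variable_scope("head"):
        w_head = tf.get_variable(
            "w", [1, 1, 4, 2],
            initializer=tf.random_normal_initializer(stddev=0.1, seed=1))
        logits = tf.nn.conv2d(
            frozen_features, w_head, strides=[1, 1, 1, 1], padding="SAME")

    def pixel_loss(logits_tensor):
        """Mean per-pixel softmax cross-entropy (hypothetical helper)."""
        flat_logits = tf.reshape(logits_tensor, [-1, 2])
        flat_labels = tf.reshape(labels, [-1])
        return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=flat_labels, logits=flat_logits))

    # Placement 1: stop_gradient on the pretrained output.
    # The head still receives a gradient; the pretrained weights do not.
    loss = pixel_loss(logits)
    grad_base, grad_head = tf.gradients(loss, [w_base, w_head])

    # Placement 2 (for contrast): stop_gradient on the final logits.
    # tf.gradients returns None for every variable, so nothing can train.
    blocked_loss = pixel_loss(tf.stop_gradient(logits))
    blocked_grads = tf.gradients(blocked_loss, [w_base, w_head])

print("pretrained gradient:", grad_base)
print("head gradient:", grad_head)
print("gradients with stop_gradient on logits:", blocked_grads)
```

Only the first placement gives the optimizer something to update, so `loss` (not `blocked_loss`) is what you would pass to the optimizer's minimize call.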

Finally, you could just go through all the variables in the pretrained model and remove them from the list of trainable variables.
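One way to do this without editing the collection itself is to pass only the variables you want trained to the optimizer through the var_list argument of minimize. The sketch below assumes the `tensorflow.compat.v1` module is available and that the new layers live under a variable scope called "head"; the scope names and toy layers are hypothetical.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
    images = tf.placeholder(tf.float32, [None, 8, 8, 3], name="images")
    labels = tf.placeholder(tf.int64, [None, 8, 8], name="labels")

    # Hypothetical stand-in for the pretrained lower layers.
    with tf.variable_scope("pretrained"):
        w_base = tf.get_variable(
            "w", [3, 3, 3, 4],
            initializer=tf.random_normal_initializer(stddev=0.1, seed=0))
        base_features = tf.nn.relu(
            tf.nn.conv2d(images, w_base, strides=[1, 1, 1, 1], padding="SAME"))

    # New higher layers, connected directly (no stop_gradient here).
    with tf.variable_scope("head"):
        w_head = tf.get_variable(
            "w", [1, 1, 4, 2],
            initializer=tf.random_normal_initializer(stddev=0.1, seed=1))
        logits = tf.nn.conv2d(
            base_features, w_head, strides=[1, 1, 1, 1], padding="SAME")

    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.reshape(labels, [-1]), logits=tf.reshape(logits, [-1, 2])))

    # Select only the trainable variables whose names start with "head".
    head_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                  scope="head")

    # The optimizer computes and applies updates for head_vars only.
    train_op = tf.train.GradientDescentOptimizer(0.5).minimize(
        loss, var_list=head_vars)
    init = tf.global_variables_initializer()

# Random toy data in place of a real dataset.
rng = np.random.RandomState(0)
feed = {
    images: rng.rand(2, 8, 8, 3).astype(np.float32),
    labels: rng.randint(0, 2, size=(2, 8, 8)).astype(np.int64),
}

with tf.Session(graph=graph) as sess:
    sess.run(init)
    base_before, head_before = sess.run([w_base, w_head])
    for _ in range(5):
        sess.run(train_op, feed_dict=feed)
    base_after, head_after = sess.run([w_base, w_head])

print("trained variables:", [v.name for v in head_vars])
```

If you build the pretrained layers yourself, creating their variables with trainable=False has a similar effect, since they then never enter the trainable-variables collection in the first place.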
