I'm working on implementing a CNN architecture (the FCN-8s model with a pretrained VGG16) for semantic segmentation on my own data (2 classes, i.e. a binary per-pixel classification).
How I intend to go about this is:
# Shape of the logits: [batch, h, w, num_classes] = [1, 750, 750, 2]
logits = vgg_fcn.deconv_1
# Block gradients from flowing back through the pretrained layers
stopper = tf.stop_gradient(logits, 'stop_gradients')
loss = train_func.loss(stopper, labels_placeholder, 2)
train_op = train_func.training(loss, FLAGS.learning_rate)
eval_correct = train_func.accuracy_eval(logits, labels_placeholder)
accuracy_summary = tf.scalar_summary('Accuracy', eval_correct)

# In the training loop:
_, acc, loss_value = sess.run([train_op, eval_correct, loss], feed_dict=feed_dict)
def loss(logits, labels, num_classes):
    # Flatten predictions and labels to per-pixel vectors
    logits = tf.reshape(logits, [-1, num_classes])
    #epsilon = tf.constant(value=1e-4)
    #logits = logits + epsilon
    labels = tf.to_int64(tf.reshape(labels, [-1]))
    print('shape of logits: %s' % str(logits.get_shape()))
    print('shape of labels: %s' % str(labels.get_shape()))
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits, labels, name='Cross_Entropy')
    cross_entropy_mean = tf.reduce_mean(cross_entropy, name='xentropy_mean')
    # The mean cross-entropy must be added to the 'losses' collection,
    # otherwise tf.add_n below sums an empty list
    tf.add_to_collection('losses', cross_entropy_mean)
    loss = tf.add_n(tf.get_collection('losses'), name='total_loss')
    return loss
You could just run the pretrained model on its own with sess.run(pretrained_output, ...) and capture its output. After saving that output you can feed it into your model as a plain input. In this case the pretrained model is not part of the training graph at all, so the optimizer cannot propagate gradients into it.
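A minimal sketch of that approach, using hypothetical names: pretrained_output is the last tensor of the pretrained graph, images_placeholder feeds it, features_placeholder is the input of your own model, and image_batch/label_batch are numpy arrays with your data (train_op, loss and labels_placeholder are the ones from your code above, and sess is an already-initialized session):

import numpy as np

# 1) Run only the pretrained part of the graph and cache its output as numpy arrays
cached_features = sess.run(pretrained_output,
                           feed_dict={images_placeholder: image_batch})
np.save('features.npy', cached_features)

# 2) Train your own model on the cached features; the pretrained graph is not
#    involved in this run, so no gradients can ever reach it
features = np.load('features.npy')
_, loss_value = sess.run([train_op, loss],
                         feed_dict={features_placeholder: features,
                                    labels_placeholder: label_batch})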
You could also attach the pretrained model to your model in the usual way and then pass the pretrained output through tf.stop_gradient(), which prevents the optimizer from propagating gradients back into the pretrained model.
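A minimal sketch of that wiring, reusing the names from your code; where exactly you place the stop node depends on which layers you want to keep frozen, since everything below it is frozen and everything built on top of it still receives gradients:

# Forward pass is unchanged; the backward pass is cut at this node
frozen_features = tf.stop_gradient(vgg_fcn.deconv_1, name='stop_gradients')
loss = train_func.loss(frozen_features, labels_placeholder, 2)
train_op = train_func.training(loss, FLAGS.learning_rate)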
Finally, you could go through all the variables belonging to the pretrained model and exclude them from the list of trainable variables that you hand to the optimizer.
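A minimal sketch of that, assuming the pretrained variables can be identified by a name prefix such as 'vgg' (the prefix and the choice of optimizer are placeholders for whatever your graph actually uses):

# Keep only the variables that do NOT belong to the pretrained scope
my_variables = [v for v in tf.trainable_variables()
                if not v.name.startswith('vgg')]

optimizer = tf.train.GradientDescentOptimizer(FLAGS.learning_rate)
# minimize() only computes and applies gradients for the variables in var_list,
# so the pretrained weights are never updated
train_op = optimizer.minimize(loss, var_list=my_variables)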