Ryan Chesler - 5 months ago
Python Question

Loading metagraph and checkpoints in tensorflow

I have been working on this for a while now and can't seem to crack it. In other questions I have seen code samples like the ones below used to save and restore a model from the metagraph and checkpoint files, but when I do something similar it says that `w1` is undefined when the save and restore code live in separate Python files. It works OK when the restore code runs at the end of the saving script, but that defeats the purpose if I have to hand-define everything all over again in a separate file. I have looked into the checkpoint file and it seems bizarre: it is only 1 KB, has just two lines, and doesn't seem to reference any variables or hold any values. I have tried passing 'w1' as a string to the print function instead, and that returns None rather than the values I am looking for. Does this work for anyone else? If so, what do your checkpoint files look like?

#Saving
import tensorflow as tf

w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver([w1, w2])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model', global_step=1000)

#restoring
with tf.Session() as sess:
    saver = tf.train.import_meta_graph('my_test_model-1000.meta', clear_devices=True)
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    print(sess.run(w1))  # NameError: w1 is not defined in this file

Answer

Your graph is saved correctly, but restoring it does not restore the Python variables that held handles to nodes of the graph. `w1` is a Python variable that you never declared in the 'restoring' part of the code. To get a handle back on your weights:

  • you can look them up by name in the TF graph: `w1 = tf.get_default_graph().get_tensor_by_name('w1:0')` (note the `:0` output suffix). The problem is that you'll have to pay close attention to your name scopes and make sure you don't have multiple variables with the same name (in which case TF appends '_1' to one of them, so you might get the wrong one). If you go that way, TensorBoard can be of great help in finding the exact name of each variable.

  • You can use collections: add the interesting nodes to a collection before saving, and fetch them back after restoring. When building the graph, before saving it, do for instance tf.add_to_collection('weights', w1) and tf.add_to_collection('weights', w2), and in your restoring code: [w1, w2] = tf.get_collection('weights'). Then you'll be able to use w1 and w2 normally.

I think the latter, though more verbose, is probably more robust to future changes in your architecture. I know all of this looks quite verbose, but remember that you usually don't need to get handles back on all of your variables, only a few of them: the inputs, outputs, and train step are usually enough.