mathetes, 4 years ago, 179

TensorFlow: How to save/restore a model? (Python)

After you train a model in TensorFlow:


  1. How do you save the trained model?

  2. How do you later restore this saved model?

Nov 16: Updated selected answer

For TensorFlow versions prior to 0.11.0RC1, see Ryan Sepassi's answer.

Feb 17: Reformulated question to be more concise.

Answer

According to https://www.tensorflow.org/programmers_guide/meta_graph, in TensorFlow 0.11.0RC1 and later you can save and restore your model directly by calling tf.train.export_meta_graph and tf.train.import_meta_graph.
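The two functions named above can also be called directly. The sketch below assumes TensorFlow 1.13+ or 2.x, where these graph-mode APIs are reachable as tf.compat.v1 (aliased tf1 here); the file name graph-only.meta and the temporary directory are hypothetical. It writes only the graph structure and rebuilds it in a fresh graph. A .meta file holds no variable values, which is why the code below also saves a Saver checkpoint.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: TF 1.13+ or 2.x, where graph-mode APIs live under tf.compat.v1.
tf1 = tf.compat.v1

# Hypothetical output location for this sketch.
meta_path = os.path.join(tempfile.mkdtemp(), "graph-only.meta")

# Build a small graph and write only its MetaGraphDef (structure, no values).
with tf.Graph().as_default():
    tf1.Variable(tf.random.truncated_normal(shape=[10]), name="w1")
    tf1.Variable(tf.random.truncated_normal(shape=[20]), name="w2")
    tf1.train.export_meta_graph(filename=meta_path)

# Recreate the same graph structure inside a fresh, empty Graph.
with tf.Graph().as_default():
    tf1.train.import_meta_graph(meta_path)
    # The variables exist again, but they are uninitialized until a
    # checkpoint is restored or their initializers are run.
    restored_names = [v.name for v in tf1.global_variables()]
    print(restored_names)
```

Printing restored_names should list both variable names; evaluating those variables at this point would fail because no values have been loaded.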

Save the model:

import tensorflow as tf  # TensorFlow 1.x graph-mode API

w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
# Store the variables in a named collection so they can be looked up after restoring.
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# `save` calls `export_meta_graph` implicitly, so alongside the checkpoint
# data you also get the saved graph file my-model.meta.
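A minimal sketch for checking which files saver.save writes, under the same tf.compat.v1 assumption (TF 1.13+ or 2.x); the temporary directory is hypothetical and stands in for the current directory used above. With the default V2 checkpoint format one would expect the my-model.meta graph file, a my-model.index file, one or more my-model.data-* shards, and a small checkpoint state file that tf.train.latest_checkpoint reads later.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # assumption: TF 1.13+ or 2.x

# Hypothetical checkpoint directory for this sketch.
ckpt_dir = tempfile.mkdtemp()
prefix = os.path.join(ckpt_dir, "my-model")

with tf.Graph().as_default():
    tf1.Variable(tf.random.truncated_normal(shape=[10]), name="w1")
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        # save() returns the checkpoint path prefix it wrote.
        saved_prefix = saver.save(sess, prefix)

# List everything the single save() call produced.
files = sorted(os.listdir(ckpt_dir))
print(saved_prefix)
print(files)
```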

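Restore the model:

import tensorflow as tf

sess = tf.Session()
# Rebuild the graph structure from the .meta file; this returns a Saver.
new_saver = tf.train.import_meta_graph('my-model.meta')
# Load the variable values from the most recent checkpoint in the current directory.
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)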
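An end-to-end sketch of the same save/restore round trip, again assuming tf.compat.v1 (TF 1.13+ or 2.x) and a hypothetical temporary directory. As an alternative to the 'vars' collection, it iterates over tf1.global_variables() after import_meta_graph and compares the restored values with the ones that were saved. Each phase runs in its own tf.Graph() context, so the restore really starts from an empty graph.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: TF 1.13+ or 2.x

ckpt_dir = tempfile.mkdtemp()  # hypothetical location for this sketch

# 1) Build, initialize, and save; remember the values that were written.
with tf.Graph().as_default():
    w1 = tf1.Variable(tf.random.truncated_normal(shape=[10]), name="w1")
    w2 = tf1.Variable(tf.random.truncated_normal(shape=[20]), name="w2")
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        saved = {v.op.name: sess.run(v) for v in (w1, w2)}
        saver.save(sess, os.path.join(ckpt_dir, "my-model"))

# 2) In a fresh graph, rebuild from the .meta file and load the latest checkpoint.
with tf.Graph().as_default():
    new_saver = tf1.train.import_meta_graph(os.path.join(ckpt_dir, "my-model.meta"))
    with tf1.Session() as sess:
        new_saver.restore(sess, tf1.train.latest_checkpoint(ckpt_dir))
        restored = {v.op.name: sess.run(v) for v in tf1.global_variables()}

# Compare the saved and restored values variable by variable.
for name in sorted(saved):
    print(name, np.allclose(saved[name], restored[name]))
```

Each printed line should pair a variable name with True.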