BernardoGO - 1 month ago

Importing TensorFlow graph fails for uninitialized variables

I'm trying to export the multilayer perceptron example as a .pb graph.
To do this, I named the input placeholder and the output operation, then added the following line:

tf.train.write_graph(sess.graph_def, "./", "graph.pb", as_text=False)


To import, I did the following:

import tensorflow as tf
from tensorflow.python.platform import gfile

# Load the serialized GraphDef and import it into the default graph
with gfile.FastGFile("graph.pb", 'rb') as f:
    print("load graph")
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')

with tf.Session() as persisted_sess:
    persisted_result = persisted_sess.graph.get_tensor_by_name("output:0")
    # features_t holds the input features, prepared elsewhere in the script
    avd = persisted_sess.run(persisted_result, feed_dict={"input_x:0": features_t})
    print("Result:", str(avd))


The import succeeds, but the "run" line raises an error:

Traceback (most recent call last):
  File "/usr/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 972, in _do_call
    return fn(*args)
  File "/usr/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 954, in _run_fn
    status, run_metadata)
  File "/usr/lib/python3.5/contextlib.py", line 66, in __exit__
    next(self.gen)
  File "/usr/lib/python3.5/site-packages/tensorflow/python/framework/errors.py", line 463, in raise_exception_on_not_ok_status
    pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors.FailedPreconditionError: Attempting to use uninitialized value Variable_3
     [[Node: Variable_3/read = Identity[T=DT_FLOAT, _class=["loc:@Variable_3"], _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_3)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "teste.py", line 56, in <module>
    avd = persisted_sess.run(persisted_result, feed_dict={"input_x:0": features_t})
  File "/usr/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 717, in run
    run_metadata_ptr)
  File "/usr/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 915, in _run
    feed_dict_string, options, run_metadata)
  File "/usr/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 965, in _do_run
    target_list, options, run_metadata)
  File "/usr/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 985, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors.FailedPreconditionError: Attempting to use uninitialized value Variable_3
     [[Node: Variable_3/read = Identity[T=DT_FLOAT, _class=["loc:@Variable_3"], _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_3)]]

Caused by op 'Variable_3/read', defined at:
  File "teste.py", line 37, in <module>
    _ = tf.import_graph_def(graph_def, name='')
  File "/usr/lib/python3.5/site-packages/tensorflow/python/framework/importer.py", line 285, in import_graph_def
    op_def=op_def)
  File "/usr/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2380, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/usr/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1298, in __init__
    self._traceback = _extract_stack()

FailedPreconditionError (see above for traceback): Attempting to use uninitialized value Variable_3
     [[Node: Variable_3/read = Identity[T=DT_FLOAT, _class=["loc:@Variable_3"], _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_3)]]


I have tried initializing all variables, but that does not help.

Answer

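TensorFlow saves the graph definition and the variable values in separate files (the graph file and the checkpoint, respectively). The graph.pb written by tf.train.write_graph contains the Variable ops but not their trained values, which is why the variables are uninitialized after import.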
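A minimal sketch of the failure, assuming a recent TensorFlow where the graph-mode API is reached through tensorflow.compat.v1 (with TF 0.11 you would import tensorflow directly). The one-variable model (a weight named "w") is a hypothetical stand-in for the MLP; only the names input_x and output come from the question. It shows that the imported graph still fails with an uninitialized value, and that tf.global_variables() is empty after import_graph_def, because a plain GraphDef carries no variable collections. An initializer created after the import therefore has nothing to initialize.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode, assumed to mirror the TF 0.x code

# Hypothetical stand-in for the trained model: output = input_x * w
g = tf.Graph()
with g.as_default():
    x = tf.placeholder(tf.float32, shape=[None], name="input_x")
    w = tf.Variable(3.0, name="w")
    tf.multiply(x, w, name="output")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        graph_def = sess.graph_def  # what tf.train.write_graph serializes

# Import the GraphDef into a fresh graph, as the question does
g2 = tf.Graph()
with g2.as_default():
    tf.import_graph_def(graph_def, name="")
    # GraphDef has no collections, so no variables are registered here
    n_vars = len(tf.global_variables())
    with tf.Session() as sess:
        try:
            sess.run("output:0", feed_dict={"input_x:0": np.array([1.0, 2.0])})
            failed = False
        except tf.errors.FailedPreconditionError:
            failed = True  # "Attempting to use uninitialized value w"

print(n_vars, failed)
```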

To save and restore the variable values, use tf.train.Saver.

See this answer for details: http://stackoverflow.com/a/33762168/4120005

Or the documentation here: https://www.tensorflow.org/versions/r0.11/how_tos/variables/index.html#saving-variables
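A sketch of the Saver round trip under the same assumptions (tensorflow.compat.v1 API, hypothetical one-variable model, a temporary directory from tempfile as the hypothetical save location). saver.save writes the variable values plus a .meta file describing the graph; tf.train.import_meta_graph rebuilds the graph and restore fills in the values, so no initializer is run on the inference side.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Hypothetical location for the checkpoint files
ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Training side: build the model, "train" it, and save graph + values
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=[None], name="input_x")
    w = tf.Variable(3.0, name="w")
    tf.multiply(x, w, name="output")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(w.assign(5.0))  # stand-in for training updating the weight
        saver.save(sess, ckpt_path)  # writes the checkpoint and model.ckpt.meta

# Inference side: rebuild the graph from .meta, then restore the values
with tf.Graph().as_default():
    with tf.Session() as sess:
        restorer = tf.train.import_meta_graph(ckpt_path + ".meta")
        restorer.restore(sess, ckpt_path)
        result = sess.run("output:0",
                          feed_dict={"input_x:0": np.array([1.0, 2.0])})

print(result)
```

The restored weight is the "trained" 5.0 rather than the initial 3.0, which is the difference between restoring and simply re-initializing.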

If you really need to restore from the GraphDef file (*.pb) alone, for instance to load it from C++, you will need the freeze_graph.py script: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py

This script takes a GraphDef (.pb) file and a checkpoint (.ckpt) file as input and outputs a GraphDef file in which the weights are stored as constants (see the script's documentation for details).
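freeze_graph.py is built around tf.graph_util.convert_variables_to_constants. The sketch below calls that function directly inside the live session instead of running the script on a .pb and a checkpoint, so it is a swapped-in shortcut rather than the script's own workflow. Same assumptions as before: tensorflow.compat.v1 API and a hypothetical one-variable model. Note that output_node_names takes node names ("output"), not tensor names ("output:0").

```python
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Hypothetical stand-in model, frozen while its session is still open
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=[None], name="input_x")
    w = tf.Variable(3.0, name="w")
    tf.multiply(x, w, name="output")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Replace each variable feeding "output" with a Const holding its value
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ["output"])

# Write it the same way as in the question (binary .pb)
pb_path = tf.train.write_graph(frozen_graph_def, tempfile.mkdtemp(),
                               "frozen.pb", as_text=False)

# Load and run: no initializer needed, the weights are now constants
loaded = tf.GraphDef()
with open(pb_path, "rb") as f:
    loaded.ParseFromString(f.read())
with tf.Graph().as_default():
    tf.import_graph_def(loaded, name="")
    with tf.Session() as sess:
        result = sess.run("output:0",
                          feed_dict={"input_x:0": np.array([1.0, 2.0])})

print(result)
```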