mdaoust - 3 months ago 135

Tensorflow: How to replace a node in a calculation graph?

If you have two disjoint graphs and want to link them, turning this:

```python
x = tf.placeholder('float')
y = f(x)

y = tf.placeholder('float')
z = g(y)
```

into this:

```python
x = tf.placeholder('float')
y = f(x)
z = g(y)
```

Is there a way to do that? It seems like it could make construction easier in some cases.

For example, if you have a graph that takes the input image as a `tf.placeholder` and want to optimize the input image deep-dream style, is there a way to just replace the placeholder with a `tf.Variable` node? Or do you have to plan for that before building the graph?

TL;DR: If you can define the two computations as Python functions, you should do that. If you can't, there's more advanced functionality in TensorFlow to serialize and import graphs, which allows you to compose graphs from different sources.
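A minimal sketch of the first option, where `f` and `g` are hypothetical stand-ins written only with overloaded arithmetic operators. Because of that, the same functions would accept a `tf.placeholder` or a `tf.Variable` (adding ops to the graph), or, as here so the snippet runs without TensorFlow, a plain float. Swapping the placeholder for a variable in the deep-dream case then just means passing a different argument.

```python
# Hypothetical computations written only with overloaded operators, so they
# add ops to the graph when given TensorFlow 1.x tensors and compute values
# when given plain numbers.
def f(x):
    return 2.0 * x + 1.0


def g(y):
    return y * y


def model(x):
    # Linking the two computations is plain function composition: whatever
    # you pass in (placeholder, Variable, or number) becomes the input.
    return g(f(x))


print(model(3.0))  # → 49.0
```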

For the second case, one way to compose graphs in TensorFlow is to build the disjoint computations as separate `tf.Graph` objects, then convert them to serialized `GraphDef` protocol buffers using `Graph.as_graph_def()`:

```python
import tensorflow as tf  # TensorFlow 1.x API (use tf.compat.v1 in TF 2.x)

# f and g are your own graph-building functions.
with tf.Graph().as_default() as g_1:
    input = tf.placeholder(tf.float32, name="input")
    y = f(input)
    # NOTE: using identity to get a known name for the output tensor.
    output = tf.identity(y, name="output")

gdef_1 = g_1.as_graph_def()

with tf.Graph().as_default() as g_2:  # NOTE: g_2 not g_1
    input = tf.placeholder(tf.float32, name="input")
    z = g(input)
    output = tf.identity(z, name="output")  # NOTE: z, not y

gdef_2 = g_2.as_graph_def()
```

Then you could compose `gdef_1` and `gdef_2` into a third graph, using `tf.import_graph_def()`:

```python
with tf.Graph().as_default() as g_combined:
    x = tf.placeholder(tf.float32, name="x")

    # Import gdef_1, which performs f(x).
    # "input:0" and "output:0" are the names of tensors in gdef_1.
    y, = tf.import_graph_def(gdef_1, input_map={"input:0": x},
                             return_elements=["output:0"])

    # Import gdef_2, which performs g(y).
    z, = tf.import_graph_def(gdef_2, input_map={"input:0": y},
                             return_elements=["output:0"])
```
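To make the `input_map` rewiring concrete, here is a toy model in plain Python; it is not how TensorFlow implements `tf.import_graph_def()`. A "graph def" is a hypothetical dict from node name to `(op, input names)`; importing copies each node under a prefix, and any input listed in `input_map` is redirected to an existing node of the destination graph. The toy uses bare node names where TensorFlow uses tensor names such as `"input:0"`, the ops `double` and `square` are hypothetical stand-ins for `f` and `g`, and the `import`/`import_1` prefixes are only chosen to resemble TensorFlow's default `import/` scope.

```python
# Toy model of graph composition (NOT TensorFlow's implementation).
# A "graph def" here is a dict: node name -> (op name, [input node names]).
OPS = {
    "double": lambda a: 2 * a,    # hypothetical stand-in for f
    "square": lambda a: a * a,    # hypothetical stand-in for g
    "identity": lambda a: a,
}


def import_graph_def(graph, gdef, input_map, return_elements, prefix):
    """Copy gdef's nodes into `graph` under `prefix`, rewiring mapped inputs."""
    for name, (op, inputs) in gdef.items():
        # Inputs found in input_map point at existing nodes of `graph`;
        # all other inputs point at the prefixed copies of gdef's own nodes.
        rewired = [input_map.get(i, f"{prefix}/{i}") for i in inputs]
        graph[f"{prefix}/{name}"] = (op, rewired)
    return [f"{prefix}/{r}" for r in return_elements]


def evaluate(graph, name, feeds):
    """Recursively compute node `name`; placeholders must appear in `feeds`."""
    if name in feeds:
        return feeds[name]
    op, inputs = graph[name]
    if op == "placeholder":
        raise KeyError(f"placeholder {name!r} was not fed")
    return OPS[op](*(evaluate(graph, i, feeds) for i in inputs))


# Two disjoint graphs, each with an "input" placeholder and an "output" node.
gdef_1 = {"input": ("placeholder", []), "y": ("double", ["input"]),
          "output": ("identity", ["y"])}
gdef_2 = {"input": ("placeholder", []), "z": ("square", ["input"]),
          "output": ("identity", ["z"])}

combined = {"x": ("placeholder", [])}
y, = import_graph_def(combined, gdef_1, {"input": "x"}, ["output"], "import")
z, = import_graph_def(combined, gdef_2, {"input": y}, ["output"], "import_1")

print(y, z)                             # import/output import_1/output
print(evaluate(combined, z, {"x": 3}))  # → 36, i.e. (2 * 3) ** 2
```

In this toy, the replaced `import/input` placeholder is still copied but no longer feeds anything. Pointing `input_map` at some other node instead of `x` (in the real API, for example a tensor read from a variable) is the analogue of the deep-dream substitution asked about in the question.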