Vinny M - 1 year ago
Python Question

Weights and Biases from a Trained Meta Graph

I have successfully exported a retrained InceptionV3 network as a TensorFlow meta graph, and I have read this protobuf back into Python. However, I am struggling to find a way to export each layer's weight and bias values, which I assume are stored within the meta graph protobuf, so that I can recreate the network outside of TensorFlow.

My workflow is as follows:

  1. Retrain the final layer for new categories.
  2. Export the meta graph with tf.train.export_meta_graph(filename='model.meta').
  3. Build the Python bindings using protoc and meta_graph.proto.
  4. Load the protobuf:

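    # meta_graph_pb2 is the module generated by protoc from meta_graph.proto
    import meta_graph_pb2
    saved = meta_graph_pb2.MetaGraphDef()  # export_meta_graph writes a MetaGraphDef
    with open('model.meta', 'rb') as f:
        saved.ParseFromString(f.read())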

From here I can view most aspects of the graph, like node names and such, but I think my inexperience is making it difficult to track down the correct way to access the weight and bias values for each relevant layer.

Answer

The MetaGraphDef proto doesn't actually contain the values of the weights and biases. Instead, it provides a way to associate a GraphDef with the weights stored in one or more checkpoint files written by a tf.train.Saver. The MetaGraphDef tutorial has more details, but the workflow is roughly as follows:

  1. In your training program, write out a checkpoint using a tf.train.Saver. By default, this also writes a MetaGraphDef to a .meta file in the same directory.

    saver = tf.train.Saver(...)
    # ... build and train the model in a tf.Session called sess ...
    saver.save(sess, "model")

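    You should find model.meta in your checkpoint directory, next to the checkpoint data files whose names start with model. If you pass global_step=NNNN to saver.save(), the prefix becomes model-NNNN instead (for example, model-NNNN.meta).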

  2. In another program, you can import the MetaGraphDef you just created, and restore from a checkpoint.

    sess = tf.Session()
    saver = tf.train.import_meta_graph("model.meta")  # The .meta file written by saver.save().
    saver.restore(sess, "model")  # Or whatever checkpoint prefix was written, e.g. "model-NNNN".

    If you want to get the value of each variable, you can, for instance, find the variable in the tf.all_variables() collection (renamed tf.global_variables() in TensorFlow 0.12 and later) and pass it to sess.run() to get its value. For example, to print the values of all variables, you can do the following:

    for var in tf.all_variables():
        print(var.name, sess.run(var))

    You could also filter tf.all_variables() to find the particular weights and biases that you're trying to extract from the model.
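Once the values have been fetched, recreating a layer outside TensorFlow is ordinary arithmetic. Here is a minimal pure-Python sketch, assuming you have already collected {var.name: sess.run(var).tolist()} into a dictionary. The variable names in it (final_layer/weights:0, final_layer/biases:0) and the numbers are hypothetical placeholders; substitute whatever names your own graph prints. It filters the dictionary by name and then applies one fully connected layer, y = xW + b, assuming the weight matrix was used as tf.matmul(x, W), so it has shape [in_features, out_features].

```python
from typing import Any, Dict, List

Matrix = List[List[float]]
Vector = List[float]


def select_params(values: Dict[str, Any], substring: str) -> Dict[str, Any]:
    """Keep only the entries whose variable name contains `substring`."""
    return {name: v for name, v in values.items() if substring in name}


def dense(x: Vector, weights: Matrix, biases: Vector) -> Vector:
    """Compute y = xW + b for a single input vector x.

    `weights` is indexed [input_feature][output_unit], matching a
    weight matrix of shape [in_features, out_features].
    """
    if len(weights) != len(x):
        raise ValueError("weights must have one row per input feature")
    out = list(biases)  # start from the bias, then accumulate x_i * W[i][j]
    for i, xi in enumerate(x):
        row = weights[i]
        if len(row) != len(out):
            raise ValueError("each weight row must have one entry per output")
        for j, w in enumerate(row):
            out[j] += xi * w
    return out


# Hypothetical fetched values, standing in for {var.name: sess.run(var).tolist()}.
fetched = {
    "final_layer/weights:0": [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]],
    "final_layer/biases:0": [0.5, -0.5],
    "global_step:0": 1000,
}

params = select_params(fetched, "final_layer/")
y = dense([1.0, 2.0, 3.0],
          params["final_layer/weights:0"],
          params["final_layer/biases:0"])
print(y)  # [4.5, 6.5]
```

Filtering by a name substring is only one option; matching an exact list of names is stricter if several layers share a prefix.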