jbird - 1 year ago
Python Question

TensorFlow: Loading a Model with Saver V2

I just updated my local installation of TensorFlow to 0.11rc2 and got a message saying that I should add a parameter to my saver so that it saves in the version 2 (V2) format. I made this change, and now I cannot load models that were saved in this format. My model saves a checkpoint after every epoch. Each save used to produce two files; now it produces three.

To load data, I use the following code:

ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
    print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
    model.saver.restore(session, ckpt.model_checkpoint_path)
else:
    print("Created model with fresh parameters.")
    # Assumed: initialize from scratch when no checkpoint is found (not shown in the original snippet)
    session.run(tf.initialize_all_variables())
return model

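The model object has already been initialized with my program's standard hyperparameters. This works without issue with Saver V1. However, the checkpoint path evaluates to the same value regardless of version, so if the checkpoint was saved with V2, no file is found at that path: a V2 checkpoint is stored as several files that share this path as a prefix, not as a single file.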
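A minimal sketch of a version-aware existence check, assuming only the Python standard library. A V2 save writes several files that share the checkpoint prefix (for a single shard: `<prefix>.index`, `<prefix>.data-00000-of-00001`, and `<prefix>.meta`), while an unsharded V1 save writes a file at the prefix path itself. The helper name `checkpoint_files_exist` is hypothetical; later TensorFlow 1.x releases include `tf.train.checkpoint_exists`, which performs a similar check.

```python
import os


def checkpoint_files_exist(prefix):
    """Return True if a V1 or V2 checkpoint appears to exist at `prefix`.

    Hypothetical helper: a V2 checkpoint is detected by its `.index` file,
    an unsharded V1 checkpoint by a file at the prefix path itself.
    """
    if os.path.exists(prefix + ".index"):  # V2 layout: files share the prefix
        return True
    return os.path.isfile(prefix)  # V1 layout: single file at the prefix
```

Calling `os.path.exists` (or `tf.gfile.Exists`) on the bare prefix returns False for a V2 checkpoint, which is exactly the failure described above; checking for the `.index` file instead accepts both layouts.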

The contents of the checkpoint state file in that directory (when saved with either version) are:

model_checkpoint_path: "translate.ckpt-3916"
all_model_checkpoint_paths: "translate.ckpt-3916"
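The checkpoint state file is a small text-format protocol buffer, and `tf.train.get_checkpoint_state` reads it for you. The stand-in below, assuming only the standard library and a hypothetical `parse_checkpoint_state` helper, recovers the recorded value: a checkpoint prefix such as `translate.ckpt-3916`, not the name of a file on disk. It is a simplified sketch, not TensorFlow's actual parser (which may also resolve relative paths against the checkpoint directory).

```python
import re

# Matches lines such as: model_checkpoint_path: "translate.ckpt-3916"
_FIELD_RE = re.compile(r'^\s*(\w+)\s*:\s*"((?:[^"\\]|\\.)*)"\s*$')


def parse_checkpoint_state(text):
    """Simplified parser for the text-format checkpoint state file.

    Returns (model_checkpoint_path, all_model_checkpoint_paths).
    Escape sequences inside the quoted strings are not decoded.
    """
    latest = None
    all_paths = []
    for line in text.splitlines():
        match = _FIELD_RE.match(line)
        if not match:
            continue  # skip blank or unrecognized lines
        key, value = match.groups()
        if key == "model_checkpoint_path":
            latest = value
        elif key == "all_model_checkpoint_paths":
            all_paths.append(value)  # repeated field
    return latest, all_paths
```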

Is there a new way to load checkpoints saved with Saver V2? If not, how can I load them?

Edit: Changing the line

if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):

to

if ckpt and ckpt.model_checkpoint_path:

as shown in this question gets a little further, but then throws the following error:

InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [84] rhs shape= [98]
[[Node: save/Assign_54 = Assign[T=DT_FLOAT, _class=["loc:@NLC/Logistic/Linear/Bias"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/cpu:0"](NLC/Logistic/Linear/Bias, save/RestoreV2_54)]]

Answer

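The method I posted in my edit was actually the correct way to get this to work. The error occurred because the data had changed between when I made the checkpoint and when I tried to load it: in the rebuilt graph the NLC/Logistic/Linear/Bias variable has 84 entries, while the checkpoint stores 98, so the saved values cannot be assigned to it.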
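One way to confirm this kind of diagnosis is to compare the variable shapes stored in the checkpoint with those in the freshly built graph. The sketch below does the comparison on plain dictionaries mapping variable names to shapes; the helper name is hypothetical. In TensorFlow 1.x the checkpoint side can likely be obtained from `tf.train.NewCheckpointReader(path).get_variable_to_shape_map()`, but treat that wiring as an assumption to check against your installed version.

```python
def find_shape_mismatches(checkpoint_shapes, graph_shapes):
    """Return {name: (graph_shape, checkpoint_shape)} for variables that
    appear in both mappings but have different shapes.

    Both arguments map variable names (e.g. "NLC/Logistic/Linear/Bias")
    to shapes given as sequences of ints. Variables present in only one
    mapping are ignored.
    """
    mismatches = {}
    for name, graph_shape in graph_shapes.items():
        ckpt_shape = checkpoint_shapes.get(name)
        if ckpt_shape is not None and list(ckpt_shape) != list(graph_shape):
            mismatches[name] = (list(graph_shape), list(ckpt_shape))
    return mismatches
```

With the numbers from the error above, the bias would be reported as `([84], [98])`: 84 in the current graph (the `lhs` of the failing `Assign`) and 98 in the checkpoint (the `rhs`).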

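To summarize, loading from a V2 checkpoint with the code above works after changing the line

if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):

to

if ckpt and ckpt.model_checkpoint_path:

With V2, the checkpoint path is a prefix shared by several files rather than a single file, so tf.gfile.Exists returns False for it even when the checkpoint is present.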
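A sketch of the loading logic with the corrected condition. To keep it runnable without TensorFlow, the session, saver, checkpoint state, and initializer are passed in; in the real code they would presumably be the session, `model.saver`, the result of `tf.train.get_checkpoint_state(FLAGS.train_dir)`, and a variable-initialization call. The function name `restore_or_initialize` and its `init_fn` parameter are hypothetical.

```python
def restore_or_initialize(session, saver, ckpt, init_fn):
    """Restore from `ckpt` if it records a checkpoint path, else initialize.

    Works for V1 and V2 checkpoints because it only checks that a path was
    recorded, not that a file exists at that exact path (for V2 the path is
    a prefix shared by several files). Returns True if parameters were
    restored, False if fresh parameters were created.
    """
    if ckpt and ckpt.model_checkpoint_path:
        print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
        saver.restore(session, ckpt.model_checkpoint_path)
        return True
    print("Created model with fresh parameters.")
    init_fn()
    return False
```

This check trusts the checkpoint state file: if the checkpoint files themselves were deleted, `saver.restore` raises an error instead of falling back to initialization, so a prefix-aware existence check can be added if that case matters.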
