jbird - 26 days ago
Python Question

Tensorflow Loading Model with Saver v2

I just updated my local installation of TensorFlow to 0.11rc2 and got a message saying that I should add a parameter to my saver so that it saves in the version 2 format. I made that change, and now I cannot load models that were saved in this format. My model saves a checkpoint after every epoch. It used to write two files, translate.ckpt-3916 and translate.ckpt-3916.meta. Now I get three files instead: translate.ckpt-3916.index, translate.ckpt-3916.meta, and translate.ckpt-3916.data-00000-of-00001.

To load data, I use the following code:

ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
    print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
    model.saver.restore(session, ckpt.model_checkpoint_path)
else:
    print("Created model with fresh parameters.")
    session.run(tf.initialize_all_variables())
return model


Here, model is a model object that was already initialized with the standard hyperparameters of my program. This works without issue with saver v1. ckpt.model_checkpoint_path evaluates to the path to translate.ckpt-3916 regardless of version. With v2 that path is only a prefix of the saved files and no file with exactly that name exists, so tf.gfile.Exists returns False and the checkpoint is never loaded.

The contents of the checkpoint file in that directory (when saved with either version) are:

model_checkpoint_path: "translate.ckpt-3916"
all_model_checkpoint_paths: "translate.ckpt-3916"
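
For illustration, here is a minimal sketch of how the path recorded in that file could be read by hand, for example to see what the loader will be given. The function name read_latest_checkpoint_prefix is hypothetical and not part of TensorFlow, and the parsing assumes the simple two-line layout shown above.

```python
import os
import re


def read_latest_checkpoint_prefix(train_dir):
    """Return the model_checkpoint_path recorded in train_dir/checkpoint, or None.

    Hypothetical helper: it only understands the plain text layout shown above.
    """
    state_file = os.path.join(train_dir, "checkpoint")
    if not os.path.isfile(state_file):
        return None
    with open(state_file) as f:
        for line in f:
            # Matches only the model_checkpoint_path line, not all_model_checkpoint_paths.
            match = re.match(r'\s*model_checkpoint_path:\s*"(.*)"\s*$', line)
            if match:
                path = match.group(1)
                # Relative entries are resolved against the checkpoint directory.
                if not os.path.isabs(path):
                    path = os.path.join(train_dir, path)
                return path
    return None
```

Note that the value returned is the same for both saver versions. With v2 it names a prefix shared by the .index and .data files rather than a file on disk.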


Is there a new method to load data with saver v2? Otherwise, how can I load my checkpoints?

EDIT:
Changing the line
if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
to
if ckpt and ckpt.model_checkpoint_path:
as shown in this question, gets a little further but then throws the following error:

InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [84] rhs shape= [98]
[[Node: save/Assign_54 = Assign[T=DT_FLOAT, _class=["loc:@NLC/Logistic/Linear/Bias"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/cpu:0"](NLC/Logistic/Linear/Bias, save/RestoreV2_54)]]

Answer

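The change I posted in my edit was in fact the correct way to get this to work. The shape mismatch error occurred because my data had changed between saving the checkpoint and loading it, so the shape of a variable in the newly built model ([84]) no longer matched the shape stored in the checkpoint ([98]).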
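
A mismatch like this can be spotted before calling restore by comparing variable shapes. The sketch below is plain Python and assumes you have already collected two dictionaries mapping variable names to shapes, one for the current graph and one for the checkpoint. How you collect them is up to you; depending on your TensorFlow version, a checkpoint reader's variable-to-shape map may be one option. The function name find_shape_mismatches is hypothetical.

```python
def find_shape_mismatches(graph_shapes, checkpoint_shapes):
    """Return {name: (graph_shape, checkpoint_shape)} for variables whose shapes differ.

    Both arguments are plain dicts of variable name -> shape (list or tuple).
    Variables that are missing from the checkpoint are ignored here.
    """
    mismatches = {}
    for name, graph_shape in graph_shapes.items():
        saved_shape = checkpoint_shapes.get(name)
        if saved_shape is not None and list(saved_shape) != list(graph_shape):
            mismatches[name] = (list(graph_shape), list(saved_shape))
    return mismatches


# Shapes taken from the error message above: the graph variable is [84],
# the value stored in the checkpoint is [98].
graph_shapes = {"NLC/Logistic/Linear/Bias": [84]}
checkpoint_shapes = {"NLC/Logistic/Linear/Bias": [98]}

for name, (lhs, rhs) in sorted(find_shape_mismatches(graph_shapes, checkpoint_shapes).items()):
    print("%s: graph shape %s, checkpoint shape %s" % (name, lhs, rhs))
```

If this reports anything, the checkpoint was written by a model built with different dimensions, and it has to be restored into a model built from the same data or settings.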

To make it visible: loading from a V2 checkpoint with the code above works after changing the line
if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
to
if ckpt and ckpt.model_checkpoint_path:
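
If you would still like an on-disk check that works for both formats, something along these lines may do. It is a sketch using only the standard library, the name checkpoint_files_exist is hypothetical, and it assumes the file naming described in the question: a V1 checkpoint is a file named exactly like the prefix, while a V2 checkpoint has a file named prefix + ".index".

```python
import os


def checkpoint_files_exist(checkpoint_prefix):
    """Return True if a V1 or V2 checkpoint appears to exist for this prefix.

    Assumption: V1 writes a file named exactly checkpoint_prefix, and V2
    writes checkpoint_prefix + ".index" next to its data shards.
    """
    v1_present = os.path.isfile(checkpoint_prefix)
    v2_present = os.path.isfile(checkpoint_prefix + ".index")
    return v1_present or v2_present
```

The condition in the loading code could then read: if ckpt and checkpoint_files_exist(ckpt.model_checkpoint_path):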