我刚刚将Tensorflow的本地安装更新为0.11rc2,我收到一条消息,说我应该为我的保护程序添加一个参数,使其保存在版本2.我更新了这个,现在我无法加载保存在的模型这种格式。当我运行我的模型时,它会在每个时代之后保存。保存时,它用于保存名为translate.ckpt-3916.meta
和translate.ckpt-3916.index
的文件。现在我得到三个文件而不是两个,名为translate.ckpt-3916.meta
,translate.ckpt-3916.data-000000-of-000001
和ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
model.saver.restore(session, ckpt.model_checkpoint_path)
else:
print("Created model with fresh parameters.")
session.run(tf.initialize_all_variables())
return model
。
要加载数据,我使用以下代码:
model
其中ckpt.model_checkpoint_path
是已使用我的程序的标准超参数初始化的模型对象。这与saver v1没有问题。无论版本如何,translate.ckpt-3916
都会评估checkpoint
的路径,因此如果检查点是使用v2保存的,则找不到文件。
该目录中model_checkpoint_path: "translate.ckpt-3916"
all_model_checkpoint_paths: "translate.ckpt-3916"
文件的内容(使用任一版本保存时)为:
if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
是否有新方法使用saver v2加载数据?否则,我如何加载检查点?
编辑:
this question中显示的将行if ckpt and ckpt.model_checkpoint_path:
更改为InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [84] rhs shape= [98]
[[Node: save/Assign_54 = Assign[T=DT_FLOAT, _class=["loc:@NLC/Logistic/Linear/Bias"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/cpu:0"](NLC/Logistic/Linear/Bias, save/RestoreV2_54)]]
似乎可以更进一步,但会引发以下错误:
{{1}}
答案 0 :(得分:2)
我在编辑中发布的方法实际上是使其正常工作的正确方法。我得到的错误是因为我在制作检查点和尝试加载检查点之间的数据发生了变化。
只是为了让它可见,通过将行if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
更改为if ckpt and ckpt.model_checkpoint_path: