训练后,我得到了许多保存的模型。例如,在保存的模型文件夹中,我有3个保存的模型和一个名为:
的checkpoint
文件
checkpoint,
model.ckpt-1000.data-00000-of-00001,
model.ckpt-1000.index,
model.ckpt-1000.meta,
model.ckpt-2000.data-00000-of-00001,
model.ckpt-2000.index,
model.ckpt-2000.meta,
model.ckpt-3000.data-00000-of-00001,
model.ckpt-3000.index,
model.ckpt-3000.meta,
我尝试了2种不同的方式:
第一:
ckpt = tf.train.latest_checkpoint(CHECKPOINT_DIR)
saver.restore(sess, ckpt)
第二:
ckpt = tf.train.get_checkpoint_state(CHECKPOINT_DIR)
saver.restore(sess, ckpt.model_checkpoint_path)
他们都工作了!但是他们只能测试最新的模型。
如果要测试特定模型,则必须将model_checkpoint_path: "model.ckpt-3000"
文件中的model_checkpoint_path: "model.ckpt-2000"
修改为checkpoint
。
我的问题是如何一一测试所有模型? (或者,如何测试特定模型?)
答案 0 :(得分:1)
您可以使用checkpoint.restore
方法恢复特定的检查点。 除了文件名外,还必须指定index
。例如,假设您要在迭代1000处加载检查点,然后编写:
status = ckpnt.restore('./test/model.ckpt-1000')
另一次您需要在2000版迭代中加载检查点:
status = ckpnt.restore('./test/model.ckpt-2000')
完整示例:
import tensorflow as tf
v1 = tf.Variable(9., name="v1")
v2 = tf.Variable(2., name="v2")
a = tf.add(v1, v2)
ckpnt = tf.train.Checkpoint(firstVar=v1, secondVar=v2)
with tf.Session() as sess:
# Init v1 and v2
sess.run(tf.global_variables_initializer())
# Print value of v1
print(sess.run(v1))
# Save v1 and v2 variables
ckpnt.save('./test/myVar', sess)
sess.run(v1.assign(90))
sess.run(v2.assign(20))
ckpnt.save('./test/myVar', sess)
ckpnt = tf.train.Checkpoint(firstVar=v1, secondVar=v2)
status = ckpnt.restore('./test/myVar-1')
status.assert_consumed()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
status.initialize_or_restore(sess)
print(sess.run(v1))