如何加载和测试特定的TensorFlow保存的模型?

时间:2018-12-24 13:47:58

标签: python tensorflow

训练后,我得到了许多保存的模型。例如,在保存的模型文件夹中,我有3个保存的模型和一个名为:

checkpoint文件
checkpoint,  
model.ckpt-1000.data-00000-of-00001,  
model.ckpt-1000.index,  
model.ckpt-1000.meta,  
model.ckpt-2000.data-00000-of-00001,  
model.ckpt-2000.index,  
model.ckpt-2000.meta,  
model.ckpt-3000.data-00000-of-00001,  
model.ckpt-3000.index,  
model.ckpt-3000.meta,

我尝试了2种不同的方式:

第一

ckpt = tf.train.latest_checkpoint(CHECKPOINT_DIR)
saver.restore(sess, ckpt)

第二

ckpt = tf.train.get_checkpoint_state(CHECKPOINT_DIR)
saver.restore(sess, ckpt.model_checkpoint_path)

他们都工作了!但是他们只能测试最新的模型。

如果要测试特定模型,则必须将model_checkpoint_path: "model.ckpt-3000"文件中的model_checkpoint_path: "model.ckpt-2000"修改为checkpoint

我的问题是如何一一测试所有模型? (或者,如何测试特定模型?)

1 个答案:

答案 0 :(得分:1)

您可以使用checkpoint.restore方法恢复特定的检查点。 除了文件名外,还必须指定index 。例如,假设您要在迭代1000处加载检查点,然后编写:

status = ckpnt.restore('./test/model.ckpt-1000')

另一次您需要在2000版迭代中加载检查点:

status = ckpnt.restore('./test/model.ckpt-2000')

完整示例

import tensorflow as tf

v1 = tf.Variable(9., name="v1")
v2 = tf.Variable(2., name="v2")
a = tf.add(v1, v2)

ckpnt = tf.train.Checkpoint(firstVar=v1, secondVar=v2)

with tf.Session() as sess:
    # Init v1 and v2
    sess.run(tf.global_variables_initializer())
    # Print value of v1
    print(sess.run(v1))
    # Save v1 and v2 variables
    ckpnt.save('./test/myVar', sess)
    sess.run(v1.assign(90))
    sess.run(v2.assign(20))
    ckpnt.save('./test/myVar', sess)

ckpnt = tf.train.Checkpoint(firstVar=v1, secondVar=v2)
status = ckpnt.restore('./test/myVar-1')
status.assert_consumed()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    status.initialize_or_restore(sess)
    print(sess.run(v1))