TensorFlow cannot restore global_step from a checkpoint

Date: 2017-04-04 22:27:15

Tags: tensorflow

I can't seem to retrieve global_step from a saved checkpoint. My code:

# (...)
checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file), clear_devices=True)
saver.restore(sess, checkpoint_file)
for v in tf.global_variables():
    print(v)
test = tf.get_variable("global_step")
print(test)

Output:

//(...)
Tensor("global_step/read:0", shape=(), dtype=int32)
//(...)
Traceback (most recent call last):
  File "train.py", line XXX, in <module>
    test = tf.get_variable("global_step")
  File "(...)/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 988, in get_variable
    custom_getter=custom_getter)
  File "(...)/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 890, in get_variable
    custom_getter=custom_getter)
  File "(...)/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 348, in get_variable
    validate_shape=validate_shape)
  File "(...)/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 333, in _true_getter
    caching_device=caching_device, validate_shape=validate_shape)
  File "(...)/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 660, in _get_single_variable
    "but instead was +1ms." % (name, shape))
ValueError: Shape of a new variable (global_step) must be fully defined, but instead was <unknown>.

I also tried global_step:0 and global_step/read:0, but got the same result. Any tips? Or should I not be using tf.get_variable at all?

Thanks

2 Answers:

Answer 0 (score: 1)

You can only use tf.get_variable to retrieve an existing variable if that variable was originally created with tf.get_variable. In addition, the variable scope has to be set up appropriately. Here it appears to be trying to create a new variable named 'global_step', which indicates that one does not yet exist in the variable store. Here is more information on how to use tf.get_variable.
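For illustration, a minimal sketch of the reuse pattern that tf.get_variable expects (this assumes the variable was originally created with tf.get_variable in the same process, which is not the case when the graph comes from import_meta_graph):

import tensorflow as tf

# Create the variable once via tf.get_variable (e.g. when building the training graph).
global_step = tf.get_variable("global_step", shape=[], dtype=tf.int32,
                              initializer=tf.constant_initializer(0), trainable=False)

# Later, retrieve the same variable by reusing the scope instead of creating a new one.
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    same_step = tf.get_variable("global_step", dtype=tf.int32)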

I usually handle the global step like this:

# to create
global_step = tf.Variable(tf.constant(0), trainable=False, name='global_step')
tf.add_to_collection('global_step', global_step)

# to load
global_step = tf.get_collection_ref('global_step')[0]
# get the current value
gs = sess.run(global_step)
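
As a usage note, collections are stored in the MetaGraphDef, so after tf.train.import_meta_graph the same lookup works without knowing any variable names or scopes. A minimal sketch, assuming checkpoint_file points at a checkpoint that was saved with the collection above:

# Import the saved graph (including its collections), then restore the weights.
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file), clear_devices=True)
with tf.Session() as sess:
    saver.restore(sess, checkpoint_file)
    global_step = tf.get_collection_ref('global_step')[0]
    print(sess.run(global_step))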

Edit: if you can't change how the global step is saved, the following should work:

global_step = tf.get_default_graph().get_tensor_by_name('global_step:0')
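
As a follow-up, this returns a Tensor rather than a Variable, so the stored value is read by evaluating it in the restored session. A minimal sketch, assuming sess is the session the checkpoint was restored into:

global_step_tensor = tf.get_default_graph().get_tensor_by_name('global_step:0')
current_step = sess.run(global_step_tensor)  # the integer step stored in the checkpoint
print(current_step)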

Answer 1 (score: 0)

You can do it like this:

with tf.Session() as sess:
    # `scores`, `input_y`, and `saver` are assumed to be defined earlier in the graph.
    predict_top_5 = tf.nn.top_k(scores, k=5)
    label_top_5 = tf.nn.top_k(input_y, k=5)
    ckpt = tf.train.get_checkpoint_state('models')
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        # Recover the step number from the checkpoint filename itself.
        global_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
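
This works because tf.train.Saver appends the global step to the checkpoint filename when save() is called with a global_step argument, so the step can be parsed back out of the path instead of being read from a graph variable. A minimal illustration with a hypothetical path:

# Hypothetical path produced by saver.save(sess, 'models/model.ckpt', global_step=12345)
model_checkpoint_path = 'models/model.ckpt-12345'
step = int(model_checkpoint_path.split('/')[-1].split('-')[-1])
print(step)  # 12345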