I can't seem to retrieve global_step from a saved checkpoint. My code:
# (...)
checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file), clear_devices=True)
saver.restore(sess, checkpoint_file)
for v in tf.global_variables():
    print(v)
test = tf.get_variable("global_step")
print(test)
The result:
(...)
Tensor("global_step/read:0", shape=(), dtype=int32)
(...)
Traceback (most recent call last):
  File "train.py", line XXX, in <module>
    test = tf.get_variable("global_step")
  File "(...)/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 988, in get_variable
    custom_getter=custom_getter)
  File "(...)/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 890, in get_variable
    custom_getter=custom_getter)
  File "(...)/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 348, in get_variable
    validate_shape=validate_shape)
  File "(...)/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 333, in _true_getter
    caching_device=caching_device, validate_shape=validate_shape)
  File "(...)/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 660, in _get_single_variable
    "but instead was %s." % (name, shape))
ValueError: Shape of a new variable (global_step) must be fully defined, but instead was <unknown>.
I also tried global_step:0 and global_step/read:0, but got the same result. Any tips? Or should I not be using tf.get_variable?
Thanks
Answer 0 (score: 1)
You can only retrieve an existing variable with tf.get_variable if the variable was created with tf.get_variable in the first place. The variable scope must also be set appropriately. Here it appears to be trying to create a new variable named 'global_step', which indicates that one does not exist yet. Here is more information on how to use tf.get_variable.
I usually handle the global step like this:
# to create
global_step = tf.Variable(tf.constant(0), trainable=False, name='global_step')
tf.add_to_collection('global_step', global_step)
# to load
global_step = tf.get_collection_ref('global_step')[0]
# get the current value
gs = sess.run(global_step)
EDIT: If you can't change how the global step was saved, the following should work:
global_step = tf.get_default_graph().get_tensor_by_name('global_step:0')
Answer 1 (score: 0)
with tf.Session() as sess:
    predict_top_5 = tf.nn.top_k(scores, k=5)
    label_top_5 = tf.nn.top_k(input_y, k=5)
    ckpt = tf.train.get_checkpoint_state('models')
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        global_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
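The last line recovers the step from the checkpoint filename itself: the Saver appends the global step to the checkpoint path when it saves (e.g. model.ckpt-1500), so no variables need to be restored just to read it. A minimal, TensorFlow-free sketch of that parsing; the path is a made-up example, not one from the question:

```python
# The global step is encoded as the suffix after the last '-' in the
# checkpoint filename, e.g. "models/model.ckpt-1500" -> 1500.
def step_from_checkpoint_path(path):
    # Take the filename component, then the integer after the last '-'
    return int(path.split('/')[-1].split('-')[-1])

print(step_from_checkpoint_path('models/model.ckpt-1500'))  # 1500
```

Note this only works if the checkpoint was saved with a global_step argument passed to saver.save(); otherwise the filename carries no step suffix and int() will raise a ValueError.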