我尝试恢复会话并调用get_variable()
来获取类型的对象
tf.Variable(根据this answer)。
它无法找到变量。重现案例的最小例子是
如下。
首先,创建一个变量并保存会话。
import tensorflow as tf
var = tf.Variable(101)
with tf.Session() as sess:
with tf.variable_scope(''):
scoped_var = tf.get_variable('scoped_var', [])
with tf.variable_scope('', reuse=True):
new_scoped_var = tf.get_variable('scoped_var', [])
assert scoped_var is new_scoped_var
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
print(sess.run(scoped_var))
saver.save(sess, 'data/sess')
get_variables
范围内的reuse=True
工作正常。
然后,从文件中恢复会话并尝试获取变量。
import tensorflow as tf
with tf.Session() as sess:
saver = tf.train.import_meta_graph('data/sess.meta')
saver.restore(sess, 'data/sess')
for v in tf.get_collection('variables'):
print(v.name)
print(tf.get_collection(("__variable_store",)))
# Oops, it's empty!
with tf.variable_scope('', reuse=True):
# the next line fails
new_scoped_var = tf.get_variable('scoped_var', [])
print("new_scoped_var: ", new_scoped_var)
输出:
Variable:0
scoped_var:0
[]
Traceback (most recent call last):
...
ValueError: Variable scoped_var does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=None in VarScope?
我们可以看到,get_variable()
无法找到变量。和
("__variable_store",)
内部使用的get_variable()
集合,
是空的。
为什么get_variable
会失败?
答案 0 :(得分:1)
您可以试试这个,而不是处理元图(如果你想修改图表以及它是如何加载的话,这可能会有所帮助)。
import tensorflow as tf
with tf.Session() as sess:
with tf.variable_scope(''):
scoped_var = tf.get_variable('scoped_var', [])
with tf.variable_scope('', reuse=True):
new_scoped_var = tf.get_variable('scoped_var', [])
assert scoped_var is new_scoped_var
saver = tf.train.Saver()
path = tf.train.get_checkpoint_state('data/sess')
if path is not None:
saver.restore(sess, path.model_checkpoint_path)
else:
sess.run(tf.global_variables_initializer())
print(sess.run(scoped_var))
saver.save(sess, 'data/sess')
#now continue to use as you normally would with a restored model
主要区别在于您在调用saver.restore
之前设置了模型