有几个相关问题,但似乎没有解决我的具体问题。
我写了一些保存和恢复TensorFlow模型的代码。如果我保存模型并在后续的python运行中恢复模型,一切都还可以。但是,如果我尝试在同一个Python实例中保存和恢复模型,则会出现以下错误:
TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Foo/X:0", shape=(?, 4), dtype=float32) is not an element of this graph.
据我所知,变量" Foo / X"恢复后在图表中查看:
[n.name for n in tf.get_default_graph().as_graph_def().node]
我的代码的基本思想是使用对TensorFlow API的相同调用来创建/重新创建图形,然后使用tf.train.Saver()。restore()来恢复训练状态。一个简单的例子,给出了相同的错误(在函数Barfoo的最后一行):
import numpy as np
import tensorflow as tf
def Foobar():
global R1
with tf.variable_scope('Foo'):
X = tf.placeholder("float", [None, 4], name = 'X')
Y = tf.placeholder("float", [None], name = 'Y')
W = tf.Variable(tf.ones([4, 1]), name = 'W')
YH = tf.matmul(X, W, name = 'YH')
L = tf.reduce_sum(tf.nn.l2_loss(YH - Y), name = 'L')
O = tf.train.AdamOptimizer(learning_rate = 0.001, name = 'O').minimize(L)
init = tf.global_variables_initializer()
S.run(init)
for i in range(32):
l1, _ = S.run([L, O], feed_dict = {X: x, Y: y})
print(str(l1))
R1 = S.run(YH, feed_dict = {X: np.ones((1, 4))})
saver = tf.train.Saver()
saver.save(S, "TFModel/savemodel")
def Barfoo():
global R2
with tf.variable_scope('Foo'):
X = tf.placeholder("float", [None, 4], name = 'X')
Y = tf.placeholder("float", [None], name = 'Y')
W = tf.Variable(tf.ones([4, 1]), name = 'W')
YH = tf.matmul(X, W, name = 'YH')
L = tf.reduce_sum(tf.nn.l2_loss(YH - Y), name = 'L')
O = tf.train.AdamOptimizer(learning_rate = 0.001, name = 'O').minimize(L)
saver = tf.train.Saver()
saver.restore(S, tf.train.latest_checkpoint('TFModel/'))
print(str([n.name for n in tf.get_default_graph().as_graph_def().node]))
R2 = S.run(YH, feed_dict = {X: np.ones((1, 4))})
x = np.random.rand(32, 4)
y = x.sum(axis = 1) + np.random.rand(32) / 10
S = tf.Session()
R1, R2 = None, None
Foobar()
tf.reset_default_graph()
Barfoo()
print('R1: ' + str(R1))
print('R2: ' + str(R2)
为什么这段代码在尝试在Barfoo中使用变量X时出错?如果我第一次运行 Foobar ,终止程序,为什么会这样?然后运行 Barfoo ?
答案 0 :(得分:0)
以粗体回答问题,即:
为什么此代码在尝试在Barfoo中使用变量X时出错?
使用S.run,您需要计算要计算的张量的名称,以及带有张量名称作为键的feed_dict的字典。您试图将张量对象本身作为键传递,而不是将它们的名称传递给它。比较:
WRONG)
R2 = S.run("Foo/YH:0", feed_dict={"Foo/X:0": np.ones((1, 4))})
RIGHT)
import scala.io.Source
请注意,我正在通知张量的名称,而不是张量器本身(正如您的版本中所做的那样)。只需更改上面的行即可使代码正常工作。
关于第二个问题,请更清楚一点如何重现它,以便我可以检查发生了什么。