在同一个Python会话中保存和恢复Tensorflow图

时间:2017-08-29 15:56:45

标签: python tensorflow save restore

有几个相关问题,但似乎没有解决我的具体问题。

我写了一些保存和恢复TensorFlow模型的代码。如果我保存模型并在后续的python运行中恢复模型,一切都还可以。但是,如果我尝试在同一个Python实例中保存和恢复模型,则会出现以下错误:

TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Foo/X:0", shape=(?, 4), dtype=float32) is not an element of this graph.

据我所知,变量" Foo / X"恢复后在图表中查看:

[n.name for n in tf.get_default_graph().as_graph_def().node]

我的代码的基本思想是使用对TensorFlow API的相同调用来创建/重新创建图形,然后使用tf.train.Saver()。restore()来恢复训练状态。一个简单的例子,给出了相同的错误(在函数Barfoo的最后一行)

import numpy as np
import tensorflow as tf

def Foobar():
    global R1
    with tf.variable_scope('Foo'):
        X = tf.placeholder("float", [None, 4], name = 'X')
        Y = tf.placeholder("float", [None], name = 'Y')
        W = tf.Variable(tf.ones([4, 1]), name = 'W')
        YH = tf.matmul(X, W, name = 'YH')
        L = tf.reduce_sum(tf.nn.l2_loss(YH - Y), name = 'L')
        O = tf.train.AdamOptimizer(learning_rate = 0.001, name = 'O').minimize(L)
        init = tf.global_variables_initializer()
    S.run(init)
    for i in range(32):
        l1, _ = S.run([L, O], feed_dict = {X: x, Y: y})
        print(str(l1))
    R1 = S.run(YH, feed_dict = {X: np.ones((1, 4))})
    saver = tf.train.Saver()
    saver.save(S, "TFModel/savemodel")

def Barfoo():  
    global R2
    with tf.variable_scope('Foo'):
        X = tf.placeholder("float", [None, 4], name = 'X')
        Y = tf.placeholder("float", [None], name = 'Y')
        W = tf.Variable(tf.ones([4, 1]), name = 'W')
        YH = tf.matmul(X, W, name = 'YH')
        L = tf.reduce_sum(tf.nn.l2_loss(YH - Y), name = 'L')
        O = tf.train.AdamOptimizer(learning_rate = 0.001, name = 'O').minimize(L)
    saver = tf.train.Saver()
    saver.restore(S, tf.train.latest_checkpoint('TFModel/'))
    print(str([n.name for n in tf.get_default_graph().as_graph_def().node]))
    R2 = S.run(YH, feed_dict = {X: np.ones((1, 4))})

x = np.random.rand(32, 4)
y = x.sum(axis = 1) + np.random.rand(32) / 10
S = tf.Session()
R1, R2 = None, None     
Foobar()
tf.reset_default_graph()
Barfoo()
print('R1: ' + str(R1))
print('R2: ' + str(R2)

为什么这段代码在尝试在Barfoo中使用变量X时出错?如果我第一次运行 Foobar ,终止程序,为什么会这样?然后运行 Barfoo

1 个答案:

答案 0 :(得分:0)

以粗体回答问题,即:

为什么此代码在尝试在Barfoo中使用变量X时出错?

使用S.run,您需要计算要计算的张量的名称,以及带有张量名称作为键的feed_dict的字典。您试图将张量对象本身作为键传递,而不是将它们的名称传递给它。比较:

WRONG)

R2 = S.run("Foo/YH:0", feed_dict={"Foo/X:0": np.ones((1, 4))})

RIGHT)

import scala.io.Source

请注意,我正在通知张量的名称,而不是张量器本身(正如您的版本中所做的那样)。只需更改上面的行即可使代码正常工作。

关于第二个问题,请更清楚一点如何重现它,以便我可以检查发生了什么。