TensorFlow无法识别feed_dict输入

时间:2018-12-30 20:50:32

标签: python tensorflow deep-learning

我正在运行一个用于线性回归的简单神经网络。但是TensorFlow抱怨我的feed_dict占位符不是图形的元素。但是,我的占位符和模型都在图形中定义,如下所示:

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense

with tf.Graph().as_default():
    x = tf.placeholder(dtype=tf.float32, shape = (None,4))
    y = tf.placeholder(dtype=tf.float32, shape = (None,4))

    model = tf.keras.Sequential([
        Dense(units=4, activation=tf.nn.relu)
    ])
    y = model(x)

    loss = tf.reduce_mean(tf.square(y-x))
    train_op = tf.train.AdamOptimizer().minimize(loss)

with tf.Session() as sess:
    sess.run(train_op, feed_dict = {x:np.ones(dtype='float32', shape=(4)),
                                    y:5*np.ones(dtype='float32', shape=(4,))})

这给出了一个错误:

TypeError: Cannot interpret feed_dict key as Tensor: Tensor 
Tensor("Placeholder:0", shape=(?, 4), dtype=float32) is not an element of this graph.

____________ UPDATE ________________

根据@Silgon和@Mcangus的建议,我修改了代码:

g= tf.Graph()
with g.as_default():
    x = tf.placeholder(dtype=tf.float32, shape = (None,4))

    model = tf.keras.Sequential([
        Dense(units=4, activation=tf.nn.relu)
    ])
    y = model(x)

    loss = tf.reduce_mean(tf.square(y-x))
    train_op = tf.train.AdamOptimizer().minimize(loss)

    init_op = tf.group(tf.global_variables_initializer(),
                     tf.local_variables_initializer())
with tf.Session(graph=g) as sess:
    sess.run(init_op)
    for i in range(5):
        _ , answer = sess.run([train_op,loss], feed_dict = {x:np.ones(dtype='float32', shape=(1,4)),
                                                        y:5*np.ones(dtype='float32', shape=(1,4))})
        print(answer)

但是该模型似乎不是在学习:

16.0
16.0
16.0
16.0
16.0

2 个答案:

答案 0 :(得分:2)

该错误告诉您变量不是图形的元素。可能是因为它不在同一范围内。解决该问题的一种方法是采用如下结构。

# define a graph
graph = tf.Graph()
with graph.as_default():
    # placeholder
    x = tf.placeholder(...)
    y = tf.placeholder(...)
    # create model
    model = create_model(x, w, b)

with tf.Session(graph=graph) as sess:
    # initialize all the variables
    sess.run(init)

另外,正如@Mcangus指出的那样,请谨慎定义变量。

答案 1 :(得分:1)

我相信您的问题是此行:

y = model(x)

您用模型的输出覆盖了y,因此它不再是占位符。