tf.placeholder不读取rbg颜色参数

时间:2017-04-04 14:30:17

标签: tensorflow

我正在尝试将一个300 x 300 x 300张量馈送到tensorflow中的占位符,虽然我正在使用imgs = tf.placeholder(tf.float32, [1, 300, 300, 300, 3])但我得到了ValueError

Traceback (most recent call last):
File "tfvgg.py", line 265, in <module>
prob = sess.run(vgg.probs, feed_dict={vgg.imgs: train_data})
File "/home/entelechy/tfenv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 717, in run
run_metadata_ptr)
File "/home/entelechy/tfenv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 894, in _run
% (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (1, 300, 300, 300) for Tensor 'Placeholder:0', which has shape '(1, 300, 300, 300, 3)'

我是tensorflow的新手,是什么让输入尺寸与占位符兼容?

if __name__ == '__main__':
    sess = tf.Session()
    imgs = tf.placeholder(tf.float32, [1, 300, 300, 300, 3])
    vgg = vgg16(imgs, None, sess)

    # Data
    INPUT_FOLDER = 'data/cubed_data/pp/play'
    images = os.listdir(INPUT_FOLDER)
    images.sort()

    train_data = []
    for i in images:
        im = np.load(INPUT_FOLDER + "/" + i)
        train_data.append(im)

    prob = sess.run(vgg.probs, feed_dict={vgg.imgs: train_data})
    preds = (np.argsort(prob)[::-1])[0:5]
    for p in preds:
        print(class_names[p], prob[p])

1 个答案:

答案 0 :(得分:1)

您尝试进给的张量具有(1,300,300,300)的形状,并且您的占位符具有形状(1,300,300,300,3)。两者必须具有相同的形状,因此将占位符更改为:

imgs = tf.placeholder(tf.float32, [1, 300, 300, 300])

修改

随着新的错误消息,我们越来越接近问题的根源。

首先,您必须将输入带入形状(1,300,300,300,1)。由于您将图像放入列表中,因此第一个维度已存在。必须添加最后一个维度:

im = np.load(INPUT_FOLDER + "/" + i)
im = np.reshape(im, (300, 300, 300, 1))
train_data.append(im)

占位符必须设置为:

imgs = tf.placeholder(tf.float32, [1, 300, 300, 300, 1])

您的第二个问题是您的网络需要具有3个频道(rgb 3D数据)的3D图像,但您的图像只有一个频道。您必须更改网络中第一个卷积的内核维度(显然卷积位于第30行的文件“tfvgg.py”中)。

你必须找到创建该卷积的内核的位置,并将其维度更改为[无论如何,无论如何,1,无论如何]。现在第二个最后一个维度是3,但必须更改为1.此后它应该使用1个图像。

如果您以后想要批量同时输入多个图像,则必须将占位符的第一个维度调整为:

imgs = tf.placeholder(tf.float32, [number_of_images, 300, 300, 300, 1])

或任意图像编号,您也可以使用:

imgs = tf.placeholder(tf.float32, [None, 300, 300, 300, 1])

您使用的网络可能只支持1个图像,但它可能支持任意批量大小。