tf.reshape在您添加额外尺寸的情况下不起作用

时间:2018-12-22 21:10:05

标签: python tensorflow

根据tensorflow网站,tf.reshape采用某个形状的张量并将其映射到另一个形状的张量。我想将大小为[600,64]的张量映射为大小为[-1、8、8、1]的张量(其中-1位置的尺寸为600)。但这似乎不起作用。

我正在python 3.6上的tensorflow上运行它,尽管它重塑为[-1、8、8],但不会重塑为[-1、8、8、1]

import tensorflow as tf
import numpy as np
from sklearn import datasets
from sklearn.preprocessing import LabelBinarizer

# preprocessing method needed
    def flatten(array):
        temp = []
        for j in array:
            temp.extend(j)
        return temp

# preprocess the data
digits = datasets.load_digits()
images = digits.images
images = [flatten(i) for i in images]
labels = digits.target
labels = LabelBinarizer().fit_transform(labels)

# the stats needed
width = 8
height = 8
alpha = 0.1
num_labels = 10
kernel_length = 3
batch_size = 10
channels = 1

# the tensorflow placeholders and reshaping
X = tf.placeholder(tf.float32, shape = [None, width * height * channels])

# AND NOW HERE IS WHERE THE ERROR STARTS
y_true = tf.placeholder(tf.float32, shape = [None, num_labels])
X = tf.reshape(X, [-1, 8, 8, 1])

# the convolutional model
conv1 = tf.layers.conv2d(X, filters = 32, kernel_size = [kernel_length,  kernel_length])
conv2 = tf.layers.conv2d(conv1, filters = 64, kernel_size = [2, 2])
flatten = tf.reshape(X, [-1, 1])
dense1 = tf.layers.dense(flatten, units=50, activation = tf.nn.relu)
y_pred = tf.layers.dense(dense1, units=num_labels, activation = tf.nn.softmax)

# the loss and training functions
loss = tf.losses.mean_squared_error(labels=y_true, predictions=y_pred)
train = tf.train.GradientDescentOptimizer(alpha).minimize(loss)

# initializing the variables and the tf.session
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

# running the session
for i in range(batch_size):
    _, lossVal = sess.run((train, loss), feed_dict = {X:images[:600], y_true: labels[:600]})
    print(lossVal)

我不断遇到此错误: ValueError:无法为形状为((?,8,8,1)'的张量'Reshape:0'输入形状(600,64)的值 而且我觉得情况并非如此,因为8 * 8 * 1等于64。

1 个答案:

答案 0 :(得分:0)

images[:600]的形状为(600, 64),与占位符预期形状(None, 8, 8, 1)不符。

重塑数据或更改占位符的形状。

请注意,您最初将占位符形状定义为(None, 64)的事实并不重要,因为稍后再对其进行整形。