错误':遵循数据格式约定的预期输入是图像(作为Numpy数组)" channels_first"'在Keras图像示例中

时间:2017-07-24 20:43:23

标签: python numpy keras

运行keras示例脚本时遇到以下错误:

C:\Users\johnd\AppData\Local\Programs\Python\Python35\lib\site-packages\keras\preprocessing\image.py:653: UserWarning: Expected input to be images (as Numpy array) following the data format convention "channels_first" (channels on axis 1), i.e. expected either 1, 3 or 4 channels on axis 1. However, it was passed an array with shape (60000, 1, 28, 28) (1 channels).
  ' (' + str(x.shape[self.channel_axis]) + ' channels).')

这是我的剧本:

from keras.datasets import mnist
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras import backend as K
K.set_image_dim_ordering('th')
# load data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# reshape to be [samples][pixels][width][height]
X_train = X_train.reshape(X_train.shape[0], 1, 28, 28)
X_test = X_test.reshape(X_test.shape[0], 1, 28, 28)
# convert from int to float
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
# define data preparation
datagen = ImageDataGenerator(zca_whitening=True)
# fit parameters from data
datagen.fit(X_train)
# configure batch size and retrieve one batch of images
for X_batch, y_batch in datagen.flow(X_train, y_train, batch_size=9):
    # create a grid of 3x3 images
    for i in range(0, 9):
        pyplot.subplot(330 + 1 + i)
        pyplot.imshow(X_batch[i].reshape(28, 28), cmap=pyplot.get_cmap('gray'))
    # show the plot
    pyplot.show()
    break

你能告诉我出了什么问题吗?

谢谢。

1 个答案:

答案 0 :(得分:0)

这不是错误,而是警告 - 我猜你没有提到其余的代码运行正常;这是我在运行你的脚本后得到的:

enter image description here

X_train.shape (60000, 1, 28, 28)并发出类似警告:

/var/venv/tsats/local/lib/python2.7/site-packages/keras/preprocessing/image.py:648: 
UserWarning: Expected input to be images (as Numpy array) following the data format 
convention "channels_first" (channels on axis 1), i.e. expected either 1, 3 or 4 channels on 
axis 1. However, it was passed an array with shape (60000, 1, 28, 28) (1 channels).    ' (' + str(x.shape[self.channel_axis]) + ' channels).')

仔细检查警告显示它是荒谬的:它要求通道位于轴1上(它们是)并且它们具有{1, 3, 4}中的值(它们的有效值确实为1)。显然这是一个Keras错误 - 我打开了一个issue - 但它不会影响代码的其余部分和结果。

更新:根据Keras Github中的response问题,这已经在主人身上修好了(尽管我自己还没有经过测试)。