带有图像的numpy vstack

时间:2019-04-03 10:43:03

标签: python numpy python-imaging-library

我正在编写一个从图像中读取像素数据并将其存储在numpy数组中的功能,以进一步进行训练/测试拆分。

运行此代码时,它将引发异常,指出除串联轴外的所有输入数组维必须完全匹配。

我不确定为什么会发生此问题以及如何解决此问题。

from PIL import Image
import numpy as np
import os

X = np.array([])
y = []

categories = {
    'A': 1,
    'B': 2
}

root = data_dir + '/cropped_resized(128,128)/'

for path, subdirs, files in os.walk(root):
    for name in files:
        img_path = os.path.join(path,name)
        category = categories[os.path.basename(path)]
        im = Image.open(img_path)
        img_pixels = list(im.getdata())
        width, height = im.size
        X = np.vstack((X, img_pixels))
        #X = np.concatenate((X, img_pixels), axis=0)
        y.append(category)

X_train, X_test, y_train, y_test = train_test_split(X, y)

这是一张失败的图片的例子

enter image description here

1 个答案:

答案 0 :(得分:1)

确定是否要将图像设置为RGB或灰度,并确保它们在加载时如此。

具体来说,更改此行:

im = Image.open(img_path)

im = Image.open(img_path).convert('RGB')

im = Image.open(img_path).convert('L')