我正在编写一个从图像中读取像素数据并将其存储在numpy数组中的功能,以进一步进行训练/测试拆分。
运行此代码时,它将引发异常,指出除串联轴外的所有输入数组维必须完全匹配。
我不确定为什么会发生此问题以及如何解决此问题。
from PIL import Image
import numpy as np
import os
X = np.array([])
y = []
categories = {
'A': 1,
'B': 2
}
root = data_dir + '/cropped_resized(128,128)/'
for path, subdirs, files in os.walk(root):
for name in files:
img_path = os.path.join(path,name)
category = categories[os.path.basename(path)]
im = Image.open(img_path)
img_pixels = list(im.getdata())
width, height = im.size
X = np.vstack((X, img_pixels))
#X = np.concatenate((X, img_pixels), axis=0)
y.append(category)
X_train, X_test, y_train, y_test = train_test_split(X, y)
这是一张失败的图片的例子
答案 0 :(得分:1)
确定是否要将图像设置为RGB或灰度,并确保它们在加载时如此。
具体来说,更改此行:
im = Image.open(img_path)
到
im = Image.open(img_path).convert('RGB')
或
im = Image.open(img_path).convert('L')