我正在尝试使用Keras库的预测函数,但它接受输入为4维数组。但是,我的数据生成器正在生成输入为二维数组。更具体一点;
预期输入---> [n,通道,宽度,高度]
我的输入---> [n,宽度*高度]
我需要在不更改数据生成器的情况下处理此问题。但是,我无法弄清楚如何覆盖Keras Library的预测功能。或者还有其他方法可以解决这个问题吗?
我的数据生成器
def genData(n=1000, max_digs=2, width=60):
capgen = ImageCaptcha()
data = []
target = []
for i in range(n):
x = np.random.randint(0, 10 ** max_digs)
img = misc.imread(capgen.generate(str(x)))
img = np.mean(img, axis=2)[:, :width]
data.append(img.flatten())
target.append(x)
return np.array(data), np.array(target)