如何覆盖keras预测功能?

时间:2017-05-13 09:35:35

标签: python tensorflow keras

我正在尝试使用Keras库的预测函数,但它接受输入为4维数组。但是,我的数据生成器正在生成输入为二维数组。更具体一点;

预期输入---> [n,通道,宽度,高度]

我的输入---> [n,宽度*高度]

我需要在不更改数据生成器的情况下处理此问题。但是,我无法弄清楚如何覆盖Keras Library的预测功能。或者还有其他方法可以解决这个问题吗?

我的数据生成器

def genData(n=1000, max_digs=2, width=60):
    capgen = ImageCaptcha()
    data = []
    target = []
    for i in range(n):
        x = np.random.randint(0, 10 ** max_digs)
        img = misc.imread(capgen.generate(str(x)))
        img = np.mean(img, axis=2)[:, :width]
        data.append(img.flatten())
        target.append(x)
    return np.array(data), np.array(target)

0 个答案:

没有答案