我正在尝试设计多输入keras模型。我们正在处理数学结的图像(128x128x3)。我建立了一个需要三个输入的模型。三个输入将是
1)。不旋转的结图
2)。与1相同的结,但绕其y轴旋转90度
3)。与1相同的结,但绕其x轴旋转了90度
该模型很好,可以正确编译。我遇到的问题是我正在使用fit_generator训练模型,而我似乎无法使数据生成器正常工作。这是我的数据生成器代码:
def DataGen(in1,in2,in3,in1_label,in2_label,in3_label, batch_size):
in1 = np.array(in1)
in1 = np.reshape(in1, (in1.shape[0],128,128,3))
in2 = np.array(in2)
in2 = np.reshape(in2, (in2.shape[0],128,128,3))
in3 = np.array(in3)
in3 = np.reshape(in3, (in3.shape[0],128,128,3))
L = len(in1)
batch_start = 0
batch_end = batch_size
gen = ImageDataGenerator(rescale=1.0/255)
genX1 = gen.flow(in1, in1_label, batch_size=batch_size, seed=1)
genX2 = gen.flow(in2, in2_label, batch_size=batch_size, seed=1)
genX3 = gen.flow(in3, in3_label, batch_size=batch_size, seed=1)
#this line is just to make the generator infinite, keras needs that
while True:
limit = min(batch_end, L)
#in1
X1i = genX1.next()
print(X1i[0].shape)
#in2
X2i = genX2.next()
#in3
X3i = genX3.next()
#print(Y.shape)
#print(Y1.shape)
#print(Y2.shape)
label = np.concatenate([X1i[1],X2i[1],X3i[1]])
#print(label.shape)
#print(X1i[1].shape)
yield [X1i[0],X2i[0],X3i[0]],np.array(label) #a tuple with two numpy arrays with batch_size samples
batch_start += batch_size
batch_end += batch_size
if batch_start > L - batch_size:
batch_start = 0
batch_end = batch_size
如果我使用此代码运行神经网络,它将生成以下错误消息:
Input arrays should have the same number of samples as target arrays. Found 30 input samples and 90 target samples.
这让我觉得我不应该连接标签,而只是返回标签列表以及批次列表...如果我得到以下内容
yield [X1i[0],X2i[0],X3i[0]],[X1i[1],X2i[1],X3i[1]]
我收到以下错误消息:
Error when checking model target: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 1 array(s), but instead got the following list of 3 arrays: [array([[0., 1., 0.],
[0., 0., 1.],
[1., 0., 0.],
[0., 1., 0.],
[1., 0., 0.],
[1., 0., 0.],
[0., 0., 1.],
[1., 0., 0.],
[0., 1., 0.],
[0....
所以我不知道该怎么办。该模型有45,000个输入图像,三个输入中的每一个都有15,000个。该模型还将具有45,000个验证图像,每个输入还有15,000个验证图像。
每个输入(15,000张图像集)都有一个具有以下结构的相应标签:列表的前5,000个元素=标签0,列表的后5,000个元素=标签1,列表的前5,000个元素=标签2。标签也全部是One编码的。
任何建议都将不胜感激。
我的数据生成器基于: https://stackoverflow.com/a/49405175/5432071
我的Jupyter笔记本代码可以在这里查看: https://uofstthomasmn-my.sharepoint.com/:u:/g/personal/ward0001_stthomas_edu/EcHhuXpXl1VJu8yKZaWRdKkBdwVrD6AEs3hd4Kuwk2Cl3g?e=tlQTGx
对于Jupyer笔记本电脑链接,您将必须单击该链接,然后单击下载...,因为html文件将不会显示在一个驱动器上。
总结我的问题:多输入模型的数据生成器有什么问题?
谢谢。