我正在对大约3000张图像的数据集进行多标签图像分类。因为这是我工作内存的限制,并且数据集将增加,所以我尝试实现自己的生成器,因为我还从在线资源中解析了图像。该网络的准确度达到25%,其中三个最高准确度的标签可以很好地呈现图像。
正常批次的形状为(32、64、64、3),标签的形状为(32、57)。 我的模型如下:
def createModel(shape, classes):
x = shape[0]
y = shape[1]
z = shape[2]
model = Sequential()
model.add(Conv2D(32, (2, 2), padding='same',input_shape=(x,y,z)))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Conv2D(32, (2, 1), padding='same'))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Conv2D(32, (1, 2), padding='same'))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2), strides = 2))
model.add(Conv2D(48, (2, 2), padding='same'))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2), strides = 2))
model.add(Conv2D(80, (2, 2), padding='same'))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2), strides = 2))
model.add(Flatten())
model.add(Dense(classes, activation = 'sigmoid'))
其中类为57,x,y,z为64,64,3。
我的生成器如下:
def generator(data, urls, labels, batch_size):
counter = 0
X_train = []
Y_train = []
while 1:
for i in range(len(data)):
if counter == batch_size:
yield (np.array(X_train), np.array(Y_train))
X_train = []
Y_train = []
counter = 0
try:
ID = data[i][0]
if random.uniform(0,1) > 0.5:
X_train.append(getImage(64, urls[ID]))
else:
X_train.append(np.flip(getImage(64, urls[ID]),1))
Y_train.append(labels[i])
counter+=1
except:
continue
其中data是具有图像ID和标签的列表,urls是具有图像ID和查找图像的URL的列表,label是由MultiLabelBinarizer()(.fit_transform)和batch_size转换的标签。 getImage()函数生成一个np.array(),其中64表示形状。
主要电话:
epochs = 60
lr = 1e-6
mlc = model.createModel((64,64,3), 57)
opt = Adam(lr=lr, decay=lr / epochs)
trainGenerator = data.generator(structuredData, urls, mlb_labels, 32)
validationGenerator = data.generator(structuredData, urls, mlb_labels, 32)
mlc.compile(loss="categorical_crossentropy", optimizer=opt, metrics=
["accuracy"])
mlc.fit_generator(trainGenerator, steps_per_epoch = 10, epochs=epochs,
validation_steps = 1, validation_data=validationGenerator)
mlc.save("datagenerator_test.h5")
此外,如果我不使用发电机,则网络已经可以正常工作并进行训练,对于发电机而言,它似乎获得1%至3%的随机精度。我希望它能提供足够的信息。
编辑:我大约需要90秒来准备一批32张图像。培训是否等待批次准备就绪?