如何训练两个频率不同的班级的cnn?

时间:2019-03-28 08:26:37

标签: python keras deep-learning conv-neural-network

我正在训练一个简单的卷积神经网络(CNN),该网络应执行二进制分类。我使用的软件包是keras。 我需要的是训练不平衡。例如,应该对一个班级进行900张图像的培训,而另一班应仅对300张图像进行培训。

我正在使用的代码如下:

from keras.models import Sequential
from keras.layers import Conv2D
from keras.layers import MaxPooling2D
from keras.layers import Flatten
from keras.layers import Dense

classifier = Sequential()
classifier.add(Conv2D(32, (3, 3),
                      input_shape=(64, 64, 3),
                      activation='relu'))

classifier.add(MaxPooling2D(pool_size=(2, 2)))
classifier.add(Flatten())
classifier.add(Dense(units=128, activation='relu'))
classifier.add(Dense(units=1, activation='sigmoid'))
classifier.compile(optimizer='adam',
                   loss='binary_crossentropy',
                   metrics=['accuracy'])

from keras.preprocessing.image import ImageDataGenerator
train_datagen = ImageDataGenerator(rescale=1./255,
                                   shear_range=0.2,
                                   zoom_range=0.2,
                                   horizontal_flip=True)

test_datagen = ImageDataGenerator(rescale=1./255)
training_set = train_datagen.flow_from_directory('dataset/training_set',
                                                 target_size=(64, 64),
                                                 batch_size=32,
                                                 class_mode='binary')

test_set = test_datagen.flow_from_directory('dataset/test_set',
                                            target_size=(64, 64),
                                            batch_size=32,
                                            class_mode='binary')

classifier.fit_generator(training_set,
                         steps_per_epoch=1200,
                         epochs=30,
                         validation_data=test_set,
                         validation_steps=50) 

现在,正在使用批处理大小为32的模型进行训练。 我猜这意味着它需要从一个班级中拿出16个培训示例,从另一个班级中拿出16个培训示例? 我需要的是从其中一个班级中选取24个培训示例,从另一个班级中选取8个示例。 可能我应该以某种方式修改与训练数据集有关的flow_from_directory()函数。不幸的是,keras文档中没有与此相关的内容。 你有什么建议吗?

0 个答案:

没有答案