我创建了一个简单的CNN,以区分5种不同的花朵。我想扩展CNN以识别更多对象。例如,我希望CNN能够识别一杯啤酒,窗户,树木等的图像。 下面是我对花进行分类的代码,它的效果很好。但是如何扩展它并使其识别越来越多的对象。我不想使用任何预先训练的模型。我希望它学习对更多对象进行分类。请帮忙。
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Convolution2D, MaxPooling2D, Flatten, Dense
classifier=Sequential()
#1st Convolution Layer
classifier.add(Convolution2D(32, 3, 3, input_shape=(64,64,3),activation="relu"))
#Pooling
classifier.add(MaxPooling2D(pool_size = (2, 2)))
# Adding a second convolutional layer
classifier.add(Convolution2D(32, 3, 3, activation = 'relu'))
classifier.add(MaxPooling2D(pool_size = (2, 2)))
# Flattening
classifier.add(Flatten())
classifier.add(Dense(output_dim = 128, activation = 'relu'))
classifier.add(Dense(output_dim = 64, activation = 'relu'))
classifier.add(Dense(output_dim = 5, activation = 'softmax'))
classifier.compile(optimizer = 'adam', loss = 'categorical_crossentropy', metrics = ['accuracy'])
print(classifier.summary())
train_datagen = ImageDataGenerator(rescale = 1./255,
shear_range = 0.2,
zoom_range = 0.2,
horizontal_flip = True)
test_datagen = ImageDataGenerator(rescale = 1./255)
training_set= train_datagen.flow_from_directory('flowers/train_set',
target_size=(64,64),
batch_size=32,
class_mode='categorical')
test_set= test_datagen.flow_from_directory('flowers/test_set',
target_size=(64,64),
batch_size=32,
class_mode='categorical')
classifier.fit_generator(training_set,
samples_per_epoch = 3000,
nb_epoch = 25,
validation_data = test_set,
nb_val_samples=1000)
答案 0 :(得分:1)
好吧,您正在做的事情叫做转移学习(微调),让我举一个例子:Imganet是世界上最大的视觉识别图像数据库,它包含来自动物,汽车等线轴世界对象的1000个类...,原始的神经网络(例如VGG16,Inception Net)经过训练,可以根据此数据集重新识别每个对象。假设您有一个非常小的数据集(例如1000张图像),并且想要将其分类为3个类,但是由于图像站点太小,您的网络失败了,您选择了VGG16或Inception Net切下了最后一个。
classifier.add(Dense(output_dim = 1000, activation = 'softmax'))
classifier.add(Dense(output_dim = 3, activation = 'softmax'))
,然后重新训练最后一层,因此在松散的密集层上进行了决策或分类,其大小定义了要将每个输入分类到多少个分类。
答案 1 :(得分:1)
如果您想使用自己的模型而不是微调像Vgg或inception这样的预训练模型,例如,您应该阅读本文:
iCaRL an incremnetal network (paper)
当然,您必须更改算法和代码。我找到了这个github仓库,显然他们已经实现了:Github repo for iCaRL in tensorflow
但是您必须使用tensorflow。查看它以了解如何在您的模型中使用它(如果可能的话,我今天才找到本文和此仓库,因此我还没有对其进行研究。)
您要问的仍然是研究领域,因此还没有广泛或通用的技术。 就像我在评论中说的那样,搜索关键字“ incremental learning”,关于这个主题还有其他论文。 (请参阅iCaRL论文的相关工作会议,该主题的所有主要技术和论文都得到了很好的总结!)
另外,请注意,添加与以前的数据集(以花朵+啤酒或窗口的示例为例)非常不同的对象,应该会大大降低准确性。 您将需要训练更长的时间才能获得更好的准确性(但是您的准确性可能从未像以前那样提高)