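I am trying to retrain VGG16 to classify Lego images. However, my model has low accuracy (around 20%). What am I doing wrong? Maybe the number of FC layers is wrong, or maybe it is my ImageDataGenerator. I have approx. 2k images per class and 6 classes in total.
How I create the model: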
def vgg16Model(self, image_shape, num_classes):
    model_VGG16 = VGG16(include_top = False, weights = None)
    model_input = Input(shape = image_shape, name = 'input_layer')
    output_VGG16_conv = model_VGG16(model_input)
    #Init of FC layers
    x = Flatten(name='flatten')(output_VGG16_conv)
    x = Dense(256, activation = 'relu', name = 'fc1')(x)
    output_layer = Dense(num_classes,activation='softmax',name='output_layer')(x)
    vgg16 = Model(inputs = model_input, outputs = output_layer)
    vgg16.compile(loss = 'categorical_crossentropy', optimizer = 'adam', metrics = ['accuracy'])
    vgg16.summary()
    return vgg16
This is how I create the ImageDataGenerator and run the training:
path = "real_Legos_images/trainable_classes"
evaluate_path = "real_Legos_images/evaluation"
NN = NeuralNetwork()
gen = ImageDataGenerator(rotation_range=40, width_shift_range=0.02, shear_range=0.02,height_shift_range=0.02, horizontal_flip=True, fill_mode='nearest')
train_generator = gen.flow_from_directory(os.path.abspath(os.path.join(path)),
    target_size = (224,224), color_mode = "rgb", batch_size = 16, class_mode='categorical')
validation_generator = gen.flow_from_directory(os.path.abspath(os.path.join(evaluate_path)),
    target_size = (224,224), color_mode = "rgb", batch_size = 16, class_mode='categorical')
STEP_SIZE_TRAIN = train_generator.n//train_generator.batch_size
num_classes = len(os.listdir(os.path.abspath(os.path.join(path))))
VGG16 = NN.vgg16Model((224, 224, 3), num_classes)
VGG16.save_weights('weights.h5')
VGG16.fit_generator(train_generator, validation_data = validation_generator, validation_steps = validation_generator.n//validation_generator.batch_size,
    steps_per_epoch = STEP_SIZE_TRAIN, epochs = 50)
Answer 0: (score: 0)
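A VGG16 model created with the argument include_top = False returns feature maps with 512 channels. Usually we add a global pooling layer such as GlobalAveragePooling2D after it, which reduces the feature maps to a one-dimensional array per image. Otherwise, flattening the feature maps directly gives you a very large, unwieldy array.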
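To make the size difference concrete, here is a rough arithmetic sketch. It assumes the 224x224 input and the 256-unit fc1 layer from the question; with that input, VGG16's five 2x2 max-pooling stages shrink each side by a factor of 32, leaving a 7x7x512 feature map. The helper name dense_params is made up for illustration.

```python
# Hypothetical helper: weights plus biases of one fully connected layer.
def dense_params(n_inputs, n_units):
    return n_inputs * n_units + n_units

side = 224 // 2 ** 5   # five 2x2 max-pool stages: 224 -> 7
channels = 512         # channels in VGG16's last convolutional block

flattened = side * side * channels  # Flatten applied directly: 7 * 7 * 512
pooled = channels                   # global pooling keeps one value per channel

print("Flatten:", flattened, "inputs,", dense_params(flattened, 256), "fc1 parameters")
print("Global pooling:", pooled, "inputs,", dense_params(pooled, 256), "fc1 parameters")
```

Under these assumptions, fc1 alone needs about 6.4 million parameters after Flatten, versus about 131 thousand after a global pooling layer.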
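Answer 1: (score: 0)
You have set VGG's weights argument to None, which means your network is initialized with random weights, so you are not using any pretrained weights. I suggest setting weights to 'imagenet' so that you use a VGG network whose weights have already been pretrained on the ImageNet dataset: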
model_VGG16 = VGG16(include_top=False, weights='imagenet')
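For context, a minimal sketch of how that line could fit into the model-building code from the question. It assumes TensorFlow's bundled Keras (tensorflow.keras) and keeps the question's 256-unit head; the function name build_classifier and its weights parameter are illustrative and not part of the original code.

```python
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model


def build_classifier(num_classes, image_shape=(224, 224, 3), weights='imagenet'):
    """Illustrative builder: VGG16 convolutional base plus a small dense head."""
    # weights='imagenet' loads the pretrained convolutional filters
    # (downloaded on first use); weights=None would start from random values.
    base = VGG16(include_top=False, weights=weights, input_shape=image_shape)

    # Same head as in the question: flatten, one hidden layer, softmax output.
    x = Flatten(name='flatten')(base.output)
    x = Dense(256, activation='relu', name='fc1')(x)
    output = Dense(num_classes, activation='softmax', name='output_layer')(x)

    model = Model(inputs=base.input, outputs=output)
    model.compile(loss='categorical_crossentropy', optimizer='adam',
                  metrics=['accuracy'])
    return model
```

Calling build_classifier(6) would then give a six-class model on top of the pretrained base. Pretrained weights generally work best when the inputs are preprocessed the same way as during pretraining; Keras provides preprocess_input in tensorflow.keras.applications.vgg16 for this, and it can be passed to ImageDataGenerator through its preprocessing_function argument.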