Python:针对人脸反欺骗问题的预训练VGG人脸模型

时间:2019-03-31 08:33:24

标签: python tensorflow keras pre-trained-model vgg-net

我正在尝试通过使用预先训练的模型(例如在ImageNet上训练的VGG)来解决面部反欺骗问题。我在哪里需要检索功能?在哪一层之后?更具体地说,是否可以将最后一个完全连接层的输出从2622更改为2,就像在面对反欺骗问题中一样,我们有两个类(真实/伪造)?

实际上,使用预训练的VGG人脸模型(在ImageNet上进行训练)是否有效解决人脸反欺骗问题?并请任何教程或GitHub代码帮助我在Python中实现此目标吗?

1 个答案:

答案 0 :(得分:0)

也许来不及回答,但总比没有好。

如果样本太少或太多,这取决于您的数据集。通常,当您的数据量有限和/或在提取样品的大多数特征以提高准确性时要避免过度拟合时,建议使用预训练模型。 如果您使用的是Keras,请尝试使用VGG16:

conv_net = VGG16(weights="imagenet", 
                 include_top=False,
                 input_shape=(150, 150, 3)) # Change the shape accordingly

它为您提供了这样的图层堆栈:

Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 150, 150, 3)       0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 150, 150, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 150, 150, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 75, 75, 64)        0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 75, 75, 128)       73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 75, 75, 128)       147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 37, 37, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 37, 37, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 37, 37, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 37, 37, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 18, 18, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 18, 18, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 18, 18, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 18, 18, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 9, 9, 512)         0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 9, 9, 512)         2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 9, 9, 512)         2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 9, 9, 512)         2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 4, 4, 512)         0         
=================================================================
Total params: 14,714,688
Trainable params: 14,714,688
Non-trainable params: 0

要使用此模型,您有两种选择,一种是仅使用此模型提取要素并将其保存在磁盘上,然后在下一步中创建密集连接的层并将上一步的输出提供给模型。这种方法比我将要解释的下一种方法快得多,但是唯一的缺点是您不能使用数据增强。这是使用predict的{​​{1}}方法提取特征的方法:

conv_net

第二个选择是将密集连接的模型附加到VGG模型的顶部,冻结features_batch = conv_base.predict(inputs_batch) # Save the features in a tensor and feed them to the Dense Layer after all has been extracted 层并将数据正常地馈送到网络,这样您就可以使用数据增强功能,但仅当您使用可以访问功能强大的GPU或云。这是有关如何冻结和连接VGG顶部的Dense层的代码:

conv_net

您甚至可以通过解冻#codes adopted from "Deep Learning with Python" book from keras import models from keras import layers conv_base.trainable = False model = models.Sequential() model.add(conv_base) model.add(layers.Flatten()) model.add(layers.Dense(256, activation='relu')) model.add(layers.Dense(1, activation='sigmoid')) 的一层来适应数据来微调模型。您可以通过以下方法冻结除一层以外的所有层:

conv_net

希望它可以帮助您入门。