How to split a ResNet50 model from the top and the bottom?

Date: 2019-01-15 21:53:37

Tags: tensorflow keras deep-learning resnet

I am using a pretrained Keras model with include_top=False, but I also want to remove one residual block from the top and another from the bottom. For a VGG network this is simple, because the layers are chained directly, but in ResNet the architecture is more complex due to the skip connections, so a straightforward approach does not really work.

Can anyone recommend a resource or script for this?

import tensorflow as tf

resnet = tf.keras.applications.resnet50.ResNet50(include_top=False, weights='imagenet')

1 Answer:

Answer 0 (score: 1)

If I understand correctly, you want to remove the first residual block and the last one.

My suggestion is to use resnet.summary() to see the names of all the layers in the model. Even better, if you have TensorBoard you can clearly see the connections between them.
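For reference, a minimal way to inspect those layer names (passing weights=None here only to skip the ImageNet download; the architecture and layer names are the same). Note that recent tf.keras versions renamed the layers (e.g. conv3_block1_1_conv instead of res3a_branch2a), so check your own summary output rather than copying names from this answer:

```python
import tensorflow as tf

# weights=None avoids downloading the pretrained weights; the graph is identical
resnet = tf.keras.applications.resnet50.ResNet50(include_top=False, weights=None)

# Print every layer with its name, type, output shape and connections
resnet.summary()

# Or collect just the names so you can search them programmatically
names = [layer.name for layer in resnet.layers]
```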

Keep in mind that in a residual network, a block ends with an Add (the sum with the skip connection) followed immediately by an Activation. That Activation layer is the one you want to use as the cut point.

The block names look like res2a...: the digit 2 indicates the block (stage), and the letter the sub-block within it.
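As a quick illustration of that naming scheme, a hypothetical helper (not part of Keras, just for this answer) can decode such a name into its stage, sub-block, and branch parts:

```python
import re

def parse_resnet_name(name):
    """Decode a layer name like 'res2c_branch2b' into (stage, sub_block, branch).

    Returns None for layers that do not follow the res<stage><letter>_branch<id>
    convention (e.g. 'activation_59').
    """
    m = re.match(r"res(\d)([a-z])_branch(\w+)$", name)
    if not m:
        return None
    stage, sub_block, branch = m.groups()
    return int(stage), sub_block, branch

print(parse_resnet_name("res2c_branch2b"))  # (2, 'c', '2b')
print(parse_resnet_name("activation_59"))   # None
```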

Based on the ResNet50 architecture:

(image: ResNet50 architecture diagram)

If you want to remove the first residual block, you have to look for the end of res2c. In this case I found this:

activation_57 (Activation)      (None, 56, 56, 64)   0           bn2c_branch2a[0][0]
__________________________________________________________________________________________________
res2c_branch2b (Conv2D)         (None, 56, 56, 64)   36928       activation_57[0][0]
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 56, 56, 64)   256         res2c_branch2b[0][0]
__________________________________________________________________________________________________
activation_58 (Activation)      (None, 56, 56, 64)   0           bn2c_branch2b[0][0]
__________________________________________________________________________________________________
res2c_branch2c (Conv2D)         (None, 56, 56, 256)  16640       activation_58[0][0]
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, 56, 56, 256)  1024        res2c_branch2c[0][0]
__________________________________________________________________________________________________
add_19 (Add)                    (None, 56, 56, 256)  0           bn2c_branch2c[0][0]
                                                                 activation_56[0][0]
__________________________________________________________________________________________________
activation_59 (Activation)      (None, 56, 56, 256)  0           add_19[0][0]
__________________________________________________________________________________________________
res3a_branch2a (Conv2D)         (None, 28, 28, 128)  32896       activation_59[0][0]

The new input layer is res3a_branch2a. This way I skip the first residual block.

activation_87 (Activation)      (None, 14, 14, 256)  0           bn4f_branch2a[0][0]              
__________________________________________________________________________________________________
res4f_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_87[0][0]              
__________________________________________________________________________________________________
bn4f_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4f_branch2b[0][0]             
__________________________________________________________________________________________________
activation_88 (Activation)      (None, 14, 14, 256)  0           bn4f_branch2b[0][0]              
__________________________________________________________________________________________________
res4f_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_88[0][0]              
__________________________________________________________________________________________________
bn4f_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4f_branch2c[0][0]             
__________________________________________________________________________________________________
add_29 (Add)                    (None, 14, 14, 1024) 0           bn4f_branch2c[0][0]              
                                                                 activation_86[0][0]              
__________________________________________________________________________________________________
activation_89 (Activation)      (None, 14, 14, 1024) 0           add_29[0][0]   

If I want to remove the last residual block, I should look for the end of res4. That is activation_89.

After making these cuts, we get the following model:

(image: diagram of the resulting cut model)

from tensorflow.keras.models import Model

resnet_cut = Model(inputs=resnet.get_layer('res3a_branch2a').input,
                   outputs=resnet.get_layer('activation_89').output)
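Be aware that some Keras versions refuse an intermediate tensor as a model input, so cutting layers off the bottom may require wiring a fresh Input through the remaining layers. A minimal sketch of both cuts on a toy model with one skip connection (the layer names conv1, add1, etc. are made up for illustration, not ResNet50 names):

```python
import tensorflow as tf
from tensorflow.keras import layers, Model

# Toy network with a skip connection, standing in for ResNet50
inp = tf.keras.Input(shape=(32, 32, 3), name="img")
x = layers.Conv2D(8, 3, padding="same", name="conv1")(inp)
skip = x
x = layers.Conv2D(8, 3, padding="same", name="conv2")(x)
x = layers.Add(name="add1")([x, skip])          # end of the "block": a sum...
x = layers.Activation("relu", name="act1")(x)   # ...followed by an activation
x = layers.Conv2D(16, 3, padding="same", name="conv3")(x)
full = Model(inp, x)

# Cut from the top: reuse the original input, stop at the activation after Add
cut_top = Model(inputs=full.input, outputs=full.get_layer("act1").output)

# Cut from the bottom: create a fresh Input matching the intermediate shape
# and call the remaining layers on it
new_inp = tf.keras.Input(shape=(32, 32, 8))
y = full.get_layer("conv3")(new_inp)
cut_bottom = Model(new_inp, y)
```

The bottom cut reuses the trained layer objects (and hence their weights); only the Input placeholder is new. For ResNet50 itself you would call every remaining layer in graph order, taking care to re-wire both branches of each Add.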