我正在训练一个简单的对抗图像,以打破一个预先训练的模型。但是,我在fit()过程中获得的结果与在同一输入(恒定输入)上调用预报()不同。
model.trainable = False
gan = Sequential()
gan.add(Dense( 256 * 256 * 3, use_bias=False, input_shape=(1,)))
gan.add(Reshape((256, 256, 3)))
gan.add(model)
gan.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_2 (Dense) (None, 196608) 196608
_________________________________________________________________
reshape_2 (Reshape) (None, 256, 256, 3) 0
_________________________________________________________________
sequential_1 (Sequential) (None, 2) 24952610
=================================================================
Total params: 25,149,218
Trainable params: 196,608
Non-trainable params: 24,952,610
_________________________________________________________________
img = img.reshape(256, 256, 3)
def custom_loss(layer):
# Create a loss function that adds the MSE loss to the mean of all squared activations of a specific layer
def loss(y_true,y_pred):
y_true = K.print_tensor(y_true, message='y_true = ')
y_pred = K.print_tensor(y_pred, message='y_pred = ')
label_diff = K.square(y_pred - y_true)
return K.mean(label_diff)
# Return a function
return loss
gan.compile(optimizer='adam',
loss=custom_loss(gan.layers[1]), # Call the loss function with the selected layer
metrics=['accuracy'])
x = np.ones((1,1))
goal = np.array([0, 1])
y = goal.reshape((1,2))
gan.fit(x, y, epochs=300, verbose=1)
在fit()期间,损耗下降得很好
Epoch 1/300
1/1 [==============================] - 5s 5s/step - loss: 0.9950 - acc: 0.0000e+00
...
Epoch 300/300
1/1 [==============================] - 0s 46ms/step - loss: 0.0045 - acc: 1.0000
在后端,y_pred和y_true也是正确的
......
y_true = [[0 1]]
y_pred = [[0.100334756 0.899665236]]
y_true = [[0 1]]
y_pred = [[0.116679631 0.883320332]]
y_true = [[0 1]]
y_pred = [[0.0832592845 0.916740656]]
y_true = [[0 1]]
y_pred = [[0.098835744 0.901164234]]
y_true = [[0 1]]
y_pred = [[0.0979194269 0.902080595]]
y_true = [[0 1]]
y_pred = [[0.057831794 0.942168236]]
y_true = [[0 1]]y_pred = [[0.0760448873 0.923955142]]
y_true = [[0 1]]
y_pred = [[0.041532293 0.958467722]]
y_true = [[0 1]]
y_pred = [[0.0667938739 0.933206141]]
print(gan.predict(x))
给予
[[0.99923825 0.00076174]]
尝试使用预训练的Resnet和InceptionV3,并且都遇到相同的问题。附件是model.summary()
初始:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
inception_v3 (Model) (None, None, None, 2048) 21802784
_________________________________________________________________
global_average_pooling2d_1 ( (None, 2048) 0
_________________________________________________________________
dense_1 (Dense) (None, 1024) 2098176
_________________________________________________________________
dropout_1 (Dropout) (None, 1024) 0
_________________________________________________________________
dense_2 (Dense) (None, 1024) 1049600
_________________________________________________________________
dropout_2 (Dropout) (None, 1024) 0
_________________________________________________________________
dense_3 (Dense) (None, 2) 2050
=================================================================
Total params: 24,952,610
Trainable params: 14,264,706
Non-trainable params: 10,687,904
_________________________________________________________________
对于Resnet:
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 256, 256, 3) 0
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D) (None, 262, 262, 3) 0 input_1[0][0]
__________________________________________________________________________________________________
conv1 (Conv2D) (None, 128, 128, 64) 9472 conv1_pad[0][0]
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization) (None, 128, 128, 64) 256 conv1[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 128, 128, 64) 0 bn_conv1[0][0]
__________________________________________________________________________________________________
pool1_pad (ZeroPadding2D) (None, 130, 130, 64) 0 activation_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 64, 64, 64) 0 pool1_pad[0][0]
__________________________________________________________________________________________________
res2a_branch2a (Conv2D) (None, 64, 64, 64) 4160 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, 64, 64, 64) 256 res2a_branch2a[0][0]
__________________________________________________________________________________________________
activation_2 (Activation) (None, 64, 64, 64) 0 bn2a_branch2a[0][0]
__________________________________________________________________________________________________
res2a_branch2b (Conv2D) (None, 64, 64, 64) 36928 activation_2[0][0]
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, 64, 64, 64) 256 res2a_branch2b[0][0]
__________________________________________________________________________________________________
activation_3 (Activation) (None, 64, 64, 64) 0 bn2a_branch2b[0][0]
__________________________________________________________________________________________________
res2a_branch2c (Conv2D) (None, 64, 64, 256) 16640 activation_3[0][0]
__________________________________________________________________________________________________
res2a_branch1 (Conv2D) (None, 64, 64, 256) 16640 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, 64, 64, 256) 1024 res2a_branch2c[0][0]
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, 64, 64, 256) 1024 res2a_branch1[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 64, 64, 256) 0 bn2a_branch2c[0][0]
bn2a_branch1[0][0]
__________________________________________________________________________________________________
activation_4 (Activation) (None, 64, 64, 256) 0 add_1[0][0]
__________________________________________________________________________________________________
res2b_branch2a (Conv2D) (None, 64, 64, 64) 16448 activation_4[0][0]
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, 64, 64, 64) 256 res2b_branch2a[0][0]
__________________________________________________________________________________________________
activation_5 (Activation) (None, 64, 64, 64) 0 bn2b_branch2a[0][0]
__________________________________________________________________________________________________
res2b_branch2b (Conv2D) (None, 64, 64, 64) 36928 activation_5[0][0]
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, 64, 64, 64) 256 res2b_branch2b[0][0]
__________________________________________________________________________________________________
activation_6 (Activation) (None, 64, 64, 64) 0 bn2b_branch2b[0][0]
__________________________________________________________________________________________________
res2b_branch2c (Conv2D) (None, 64, 64, 256) 16640 activation_6[0][0]
__________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizati (None, 64, 64, 256) 1024 res2b_branch2c[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 64, 64, 256) 0 bn2b_branch2c[0][0]
activation_4[0][0]
__________________________________________________________________________________________________
activation_7 (Activation) (None, 64, 64, 256) 0 add_2[0][0]
__________________________________________________________________________________________________
res2c_branch2a (Conv2D) (None, 64, 64, 64) 16448 activation_7[0][0]
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, 64, 64, 64) 256 res2c_branch2a[0][0]
__________________________________________________________________________________________________
activation_8 (Activation) (None, 64, 64, 64) 0 bn2c_branch2a[0][0]
__________________________________________________________________________________________________
res2c_branch2b (Conv2D) (None, 64, 64, 64) 36928 activation_8[0][0]
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 64, 64, 64) 256 res2c_branch2b[0][0]
__________________________________________________________________________________________________
activation_9 (Activation) (None, 64, 64, 64) 0 bn2c_branch2b[0][0]
__________________________________________________________________________________________________
res2c_branch2c (Conv2D) (None, 64, 64, 256) 16640 activation_9[0][0]
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, 64, 64, 256) 1024 res2c_branch2c[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 64, 64, 256) 0 bn2c_branch2c[0][0]
activation_7[0][0]
__________________________________________________________________________________________________
activation_10 (Activation) (None, 64, 64, 256) 0 add_3[0][0]
__________________________________________________________________________________________________
res3a_branch2a (Conv2D) (None, 32, 32, 128) 32896 activation_10[0][0]
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, 32, 32, 128) 512 res3a_branch2a[0][0]
__________________________________________________________________________________________________
activation_11 (Activation) (None, 32, 32, 128) 0 bn3a_branch2a[0][0]
__________________________________________________________________________________________________
res3a_branch2b (Conv2D) (None, 32, 32, 128) 147584 activation_11[0][0]
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, 32, 32, 128) 512 res3a_branch2b[0][0]
__________________________________________________________________________________________________
activation_12 (Activation) (None, 32, 32, 128) 0 bn3a_branch2b[0][0]
__________________________________________________________________________________________________
res3a_branch2c (Conv2D) (None, 32, 32, 512) 66048 activation_12[0][0]
__________________________________________________________________________________________________
res3a_branch1 (Conv2D) (None, 32, 32, 512) 131584 activation_10[0][0]
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, 32, 32, 512) 2048 res3a_branch2c[0][0]
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, 32, 32, 512) 2048 res3a_branch1[0][0]
__________________________________________________________________________________________________
add_4 (Add) (None, 32, 32, 512) 0 bn3a_branch2c[0][0]
bn3a_branch1[0][0]
__________________________________________________________________________________________________
activation_13 (Activation) (None, 32, 32, 512) 0 add_4[0][0]
__________________________________________________________________________________________________
res3b_branch2a (Conv2D) (None, 32, 32, 128) 65664 activation_13[0][0]
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, 32, 32, 128) 512 res3b_branch2a[0][0]
__________________________________________________________________________________________________
activation_14 (Activation) (None, 32, 32, 128) 0 bn3b_branch2a[0][0]
__________________________________________________________________________________________________
res3b_branch2b (Conv2D) (None, 32, 32, 128) 147584 activation_14[0][0]
__________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizati (None, 32, 32, 128) 512 res3b_branch2b[0][0]
__________________________________________________________________________________________________
activation_15 (Activation) (None, 32, 32, 128) 0 bn3b_branch2b[0][0]
__________________________________________________________________________________________________
res3b_branch2c (Conv2D) (None, 32, 32, 512) 66048 activation_15[0][0]
__________________________________________________________________________________________________
bn3b_branch2c (BatchNormalizati (None, 32, 32, 512) 2048 res3b_branch2c[0][0]
__________________________________________________________________________________________________
add_5 (Add) (None, 32, 32, 512) 0 bn3b_branch2c[0][0]
activation_13[0][0]
__________________________________________________________________________________________________
activation_16 (Activation) (None, 32, 32, 512) 0 add_5[0][0]
__________________________________________________________________________________________________
res3c_branch2a (Conv2D) (None, 32, 32, 128) 65664 activation_16[0][0]
__________________________________________________________________________________________________
bn3c_branch2a (BatchNormalizati (None, 32, 32, 128) 512 res3c_branch2a[0][0]
__________________________________________________________________________________________________
activation_17 (Activation) (None, 32, 32, 128) 0 bn3c_branch2a[0][0]
__________________________________________________________________________________________________
res3c_branch2b (Conv2D) (None, 32, 32, 128) 147584 activation_17[0][0]
__________________________________________________________________________________________________
bn3c_branch2b (BatchNormalizati (None, 32, 32, 128) 512 res3c_branch2b[0][0]
__________________________________________________________________________________________________
activation_18 (Activation) (None, 32, 32, 128) 0 bn3c_branch2b[0][0]
__________________________________________________________________________________________________
res3c_branch2c (Conv2D) (None, 32, 32, 512) 66048 activation_18[0][0]
__________________________________________________________________________________________________
bn3c_branch2c (BatchNormalizati (None, 32, 32, 512) 2048 res3c_branch2c[0][0]
__________________________________________________________________________________________________
add_6 (Add) (None, 32, 32, 512) 0 bn3c_branch2c[0][0]
activation_16[0][0]
__________________________________________________________________________________________________
activation_19 (Activation) (None, 32, 32, 512) 0 add_6[0][0]
__________________________________________________________________________________________________
res3d_branch2a (Conv2D) (None, 32, 32, 128) 65664 activation_19[0][0]
__________________________________________________________________________________________________
bn3d_branch2a (BatchNormalizati (None, 32, 32, 128) 512 res3d_branch2a[0][0]
__________________________________________________________________________________________________
activation_20 (Activation) (None, 32, 32, 128) 0 bn3d_branch2a[0][0]
__________________________________________________________________________________________________
res3d_branch2b (Conv2D) (None, 32, 32, 128) 147584 activation_20[0][0]
__________________________________________________________________________________________________
bn3d_branch2b (BatchNormalizati (None, 32, 32, 128) 512 res3d_branch2b[0][0]
__________________________________________________________________________________________________
activation_21 (Activation) (None, 32, 32, 128) 0 bn3d_branch2b[0][0]
__________________________________________________________________________________________________
res3d_branch2c (Conv2D) (None, 32, 32, 512) 66048 activation_21[0][0]
__________________________________________________________________________________________________
bn3d_branch2c (BatchNormalizati (None, 32, 32, 512) 2048 res3d_branch2c[0][0]
__________________________________________________________________________________________________
add_7 (Add) (None, 32, 32, 512) 0 bn3d_branch2c[0][0]
activation_19[0][0]
__________________________________________________________________________________________________
activation_22 (Activation) (None, 32, 32, 512) 0 add_7[0][0]
__________________________________________________________________________________________________
res4a_branch2a (Conv2D) (None, 16, 16, 256) 131328 activation_22[0][0]
__________________________________________________________________________________________________
bn4a_branch2a (BatchNormalizati (None, 16, 16, 256) 1024 res4a_branch2a[0][0]
__________________________________________________________________________________________________
activation_23 (Activation) (None, 16, 16, 256) 0 bn4a_branch2a[0][0]
__________________________________________________________________________________________________
res4a_branch2b (Conv2D) (None, 16, 16, 256) 590080 activation_23[0][0]
__________________________________________________________________________________________________
bn4a_branch2b (BatchNormalizati (None, 16, 16, 256) 1024 res4a_branch2b[0][0]
__________________________________________________________________________________________________
activation_24 (Activation) (None, 16, 16, 256) 0 bn4a_branch2b[0][0]
__________________________________________________________________________________________________
res4a_branch2c (Conv2D) (None, 16, 16, 1024) 263168 activation_24[0][0]
__________________________________________________________________________________________________
res4a_branch1 (Conv2D) (None, 16, 16, 1024) 525312 activation_22[0][0]
__________________________________________________________________________________________________
bn4a_branch2c (BatchNormalizati (None, 16, 16, 1024) 4096 res4a_branch2c[0][0]
__________________________________________________________________________________________________
bn4a_branch1 (BatchNormalizatio (None, 16, 16, 1024) 4096 res4a_branch1[0][0]
__________________________________________________________________________________________________
add_8 (Add) (None, 16, 16, 1024) 0 bn4a_branch2c[0][0]
bn4a_branch1[0][0]
__________________________________________________________________________________________________
activation_25 (Activation) (None, 16, 16, 1024) 0 add_8[0][0]
__________________________________________________________________________________________________
res4b_branch2a (Conv2D) (None, 16, 16, 256) 262400 activation_25[0][0]
__________________________________________________________________________________________________
bn4b_branch2a (BatchNormalizati (None, 16, 16, 256) 1024 res4b_branch2a[0][0]
__________________________________________________________________________________________________
activation_26 (Activation) (None, 16, 16, 256) 0 bn4b_branch2a[0][0]
__________________________________________________________________________________________________
res4b_branch2b (Conv2D) (None, 16, 16, 256) 590080 activation_26[0][0]
__________________________________________________________________________________________________
bn4b_branch2b (BatchNormalizati (None, 16, 16, 256) 1024 res4b_branch2b[0][0]
__________________________________________________________________________________________________
activation_27 (Activation) (None, 16, 16, 256) 0 bn4b_branch2b[0][0]
__________________________________________________________________________________________________
res4b_branch2c (Conv2D) (None, 16, 16, 1024) 263168 activation_27[0][0]
__________________________________________________________________________________________________
bn4b_branch2c (BatchNormalizati (None, 16, 16, 1024) 4096 res4b_branch2c[0][0]
__________________________________________________________________________________________________
add_9 (Add) (None, 16, 16, 1024) 0 bn4b_branch2c[0][0]
activation_25[0][0]
__________________________________________________________________________________________________
activation_28 (Activation) (None, 16, 16, 1024) 0 add_9[0][0]
__________________________________________________________________________________________________
res4c_branch2a (Conv2D) (None, 16, 16, 256) 262400 activation_28[0][0]
__________________________________________________________________________________________________
bn4c_branch2a (BatchNormalizati (None, 16, 16, 256) 1024 res4c_branch2a[0][0]
__________________________________________________________________________________________________
activation_29 (Activation) (None, 16, 16, 256) 0 bn4c_branch2a[0][0]
__________________________________________________________________________________________________
res4c_branch2b (Conv2D) (None, 16, 16, 256) 590080 activation_29[0][0]
__________________________________________________________________________________________________
bn4c_branch2b (BatchNormalizati (None, 16, 16, 256) 1024 res4c_branch2b[0][0]
__________________________________________________________________________________________________
activation_30 (Activation) (None, 16, 16, 256) 0 bn4c_branch2b[0][0]
__________________________________________________________________________________________________
res4c_branch2c (Conv2D) (None, 16, 16, 1024) 263168 activation_30[0][0]
__________________________________________________________________________________________________
bn4c_branch2c (BatchNormalizati (None, 16, 16, 1024) 4096 res4c_branch2c[0][0]
__________________________________________________________________________________________________
add_10 (Add) (None, 16, 16, 1024) 0 bn4c_branch2c[0][0]
activation_28[0][0]
__________________________________________________________________________________________________
activation_31 (Activation) (None, 16, 16, 1024) 0 add_10[0][0]
__________________________________________________________________________________________________
res4d_branch2a (Conv2D) (None, 16, 16, 256) 262400 activation_31[0][0]
__________________________________________________________________________________________________
... omitted ...
Total params: 23,593,859
Trainable params: 23,540,739
Non-trainable params: 53,120
__________________________________________________________________________________________________
答案 0 :(得分:3)
那些经过预训练的模型包含BatchNormalization
层。
预计它们在训练和测试之间的表现会有所不同(Dropout
层也是如此,但差异不会那么大)。
训练中的BatchNormalization
将使用当前批次的均值和方差进行归一化,它还将对一些批次可能不能代表完整数据集这一事实应用一些统计补偿。
但是在评估过程中,BatchNormalization
将使用在训练期间收集的调整后的值进行均值和方差。 (在这种情况下,是在“预培训”期间收集的,而不是您的培训)
要使BatchNormalization
正常工作,您需要预训练模型的输入与模型原始训练数据的范围相同。否则,您必须使BatchNormalization
层可训练,以便根据数据调整均值和方差。
但是您的培训需要大量的批处理量以及实际数据才能正确进行培训。
用于训练图像的提示。
在导入预训练模型的同一模块中,可以导入preprocess_input
函数。给它加载一些keras.preprocessing.images.load_img
图像,看看模型的预期范围是多少。
使用ImageDataGenerator
时,可以传递此preprocess_input
函数,以便生成器为您提供期望的数据。