I'm training a convolutional autoencoder but I can't get the loss down, and I'm hoping someone can point out some possible improvements.
I have 1024x1024 grayscale images (I've also tried 512x512), which I want to compress for unsupervised clustering. My full model is below, but it follows a fairly basic pattern: several Conv2D layers (with max pooling), then a dense layer, which is then reshaped and upsampled back to the original image size.
Some things I've tried so far:
1) I found MSE works better as a loss function than binary cross-entropy, because the pixel intensity values are far from uniformly distributed (binary cross-entropy got stuck assigning everything a value of 1, which gives a small error but a useless output).
2) If I remove the dense layer in the middle and only compress the images slightly, I can easily reach a very low error and near-perfect image reconstruction (at least to my eye). That's fairly obvious, but I think it shows I'm not making some mistake that renders the output meaningless.
3) My loss never really drops below 0.02-0.03. At around 0.025, though, the reconstructions are good enough that the output is clearly derived from the input rather than being some kind of random noise (e.g. every pixel at the same or some arbitrary intensity). I suspect I'd need to get it below 0.01 for it to be good enough. My best value so far (on a somewhat easier subset of my data) is 0.018, and when I plot the encoded values in a heatmap I can see clear clusters among the samples.
4) When my middle dense layer used a ReLU activation, I got a lot of dying ReLUs, which made it less useful for the final clustering, so I switched to tanh. I also found 'he_normal' worked better as the initializer for the dense layer.
5) Adding more dense layers in the middle didn't seem to help at all.
6) Changing the shape of the encoder (so that it goes from fewer to more kernels per layer) didn't help either, even though I know that's traditionally what a convolutional autoencoder looks like.
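The dying-ReLU issue in 4) can be seen directly from the gradients: a unit whose pre-activation is negative outputs 0 and passes no gradient back, so it can never recover, while tanh's gradient is small but never exactly zero. A minimal framework-free illustration:

```python
import math

def relu(x):
    return max(0.0, x)

def relu_grad(x):
    # derivative of ReLU: exactly 0 for negative inputs -> "dead" unit
    return 1.0 if x > 0 else 0.0

def tanh_grad(x):
    # derivative of tanh: 1 - tanh(x)^2, small but nonzero everywhere
    return 1.0 - math.tanh(x) ** 2

pre_activations = [-2.0, -0.5, 0.3, 1.5]
print([relu(x) for x in pre_activations])       # negative inputs output 0
print([relu_grad(x) for x in pre_activations])  # ...and pass no gradient back
print([tanh_grad(x) for x in pre_activations])  # tanh gradient is always > 0
```

If many bottleneck units die, the effective embedding dimension shrinks and the encoded vectors become less informative for clustering, which matches the symptom described above.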
Here is the full model (output from model.summary()):
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_4 (InputLayer) (None, 1024, 1024, 1) 0
_________________________________________________________________
conv2d_40 (Conv2D) (None, 1024, 1024, 128) 1280
_________________________________________________________________
max_pooling2d_19 (MaxPooling (None, 512, 512, 128) 0
_________________________________________________________________
batch_normalization_4 (Batch (None, 512, 512, 128) 512
_________________________________________________________________
conv2d_41 (Conv2D) (None, 512, 512, 64) 73792
_________________________________________________________________
max_pooling2d_20 (MaxPooling (None, 256, 256, 64) 0
_________________________________________________________________
conv2d_42 (Conv2D) (None, 256, 256, 32) 18464
_________________________________________________________________
max_pooling2d_21 (MaxPooling (None, 128, 128, 32) 0
_________________________________________________________________
conv2d_43 (Conv2D) (None, 128, 128, 16) 4624
_________________________________________________________________
max_pooling2d_22 (MaxPooling (None, 64, 64, 16) 0
_________________________________________________________________
conv2d_44 (Conv2D) (None, 64, 64, 8) 1160
_________________________________________________________________
max_pooling2d_23 (MaxPooling (None, 32, 32, 8) 0
_________________________________________________________________
flatten_4 (Flatten) (None, 8192) 0
_________________________________________________________________
dense_5 (Dense) (None, 512) 4194816
_________________________________________________________________
reshape_4 (Reshape) (None, 8, 8, 8) 0
_________________________________________________________________
up_sampling2d_22 (UpSampling (None, 16, 16, 8) 0
_________________________________________________________________
conv2d_45 (Conv2D) (None, 16, 16, 16) 1168
_________________________________________________________________
up_sampling2d_23 (UpSampling (None, 32, 32, 16) 0
_________________________________________________________________
conv2d_46 (Conv2D) (None, 32, 32, 16) 2320
_________________________________________________________________
up_sampling2d_24 (UpSampling (None, 64, 64, 16) 0
_________________________________________________________________
conv2d_47 (Conv2D) (None, 64, 64, 32) 4640
_________________________________________________________________
up_sampling2d_25 (UpSampling (None, 128, 128, 32) 0
_________________________________________________________________
conv2d_48 (Conv2D) (None, 128, 128, 64) 18496
_________________________________________________________________
up_sampling2d_26 (UpSampling (None, 256, 256, 64) 0
_________________________________________________________________
conv2d_49 (Conv2D) (None, 256, 256, 128) 73856
_________________________________________________________________
up_sampling2d_27 (UpSampling (None, 512, 512, 128) 0
_________________________________________________________________
conv2d_50 (Conv2D) (None, 512, 512, 128) 147584
_________________________________________________________________
up_sampling2d_28 (UpSampling (None, 1024, 1024, 128) 0
_________________________________________________________________
conv2d_51 (Conv2D) (None, 1024, 1024, 1) 1153
=================================================================
Total params: 4,543,865
Trainable params: 4,543,609
Non-trainable params: 256
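For reference, the summary above corresponds to a Keras model along these lines. This is a reconstruction from the layer shapes and parameter counts, not the author's actual code: the 3x3 kernel sizes are inferred from the parameter counts, and the relu/sigmoid activations are assumptions (only the tanh and he_normal on the dense layer are stated in the question).

```python
from tensorflow.keras import layers, models

inp = layers.Input(shape=(1024, 1024, 1))

# Encoder: 3x3 convs (kernel size inferred from param counts) with max pooling
x = layers.Conv2D(128, 3, activation='relu', padding='same')(inp)
x = layers.MaxPooling2D(2)(x)
x = layers.BatchNormalization()(x)
for filters in (64, 32, 16, 8):
    x = layers.Conv2D(filters, 3, activation='relu', padding='same')(x)
    x = layers.MaxPooling2D(2)(x)

# Bottleneck: 32*32*8 = 8192 -> 512; tanh to avoid dying ReLUs (point 4)
x = layers.Flatten()(x)
encoded = layers.Dense(512, activation='tanh',
                       kernel_initializer='he_normal')(x)

# Decoder: reshape the 512-vector to 8x8x8, then upsample back to 1024x1024
x = layers.Reshape((8, 8, 8))(encoded)
for filters in (16, 16, 32, 64, 128, 128):
    x = layers.UpSampling2D(2)(x)
    x = layers.Conv2D(filters, 3, activation='relu', padding='same')(x)
x = layers.UpSampling2D(2)(x)
out = layers.Conv2D(1, 3, activation='sigmoid', padding='same')(x)

autoencoder = models.Model(inp, out)
autoencoder.compile(optimizer='adam', loss='mse')
```

Built this way, the layer output shapes and the total of 4,543,865 parameters match the summary above.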
Answer 0 (score: 0)
Your loss function is probably the problem. Use BCE on the logit outputs of the network; that should fix it.
tf.keras.losses.BinaryCrossentropy(from_logits=True)
Note: when reconstructing from the embedding, apply a sigmoid to the decoder's output.
z = encoder(x)                       # embedding
x_hat_raw = decoder(z)               # raw logits (no activation on the last layer)
reconstruction = sigmoid(x_hat_raw)  # map logits back to [0, 1] pixel values
Now it should train fine!
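To see why from_logits=True matters: applying the sigmoid yourself and then taking the log overflows once the logits saturate (sigmoid rounds to exactly 1.0 in floating point), while the algebraically equivalent from-logits form stays finite. A minimal sketch in plain Python; tf.keras uses a stable formulation like this internally when from_logits=True:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def bce_naive(y, z):
    # sigmoid first, then log: blows up once sigmoid(z) rounds to exactly 1.0
    p = sigmoid(z)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def bce_from_logits(y, z):
    # same quantity, rearranged to stay finite for any logit:
    # max(z, 0) - z*y + log(1 + exp(-|z|))
    return max(z, 0) - z * y + math.log1p(math.exp(-abs(z)))

print(bce_from_logits(1.0, 40.0))  # ~0: confident and correct
print(bce_from_logits(0.0, 40.0))  # ~40: confident and wrong, but still finite
# bce_naive(0.0, 40.0) raises ValueError: sigmoid(40.0) == 1.0, so log(1 - p) = log(0)
```

This is also why, as the answer notes, the sigmoid is applied only afterwards when you want actual pixel-valued reconstructions from the embedding, not inside the loss.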