High loss from convolutional autoencoder in Keras

Asked: 2019-09-24 21:48:21

Tags: keras conv-neural-network convolution autoencoder

I am training a convolutional autoencoder but cannot get the loss down, and I'm hoping someone can point out some possible improvements.

I have 1024x1024 grayscale images (I have also tried 512x512), which I want to compress for unsupervised clustering. My full model is below, but it follows a very basic pattern: several Conv2D layers (with max pooling), then a dense layer, which is then reshaped and upsampled back to the original image size.

Some things I have tried so far:

1) I found that mse works better as a loss function than binary cross-entropy, since the pixel intensity values are far from uniformly distributed (binary cross-entropy got stuck assigning every pixel a value of 1, which gives a small error but a useless output).

2) If I just remove the middle dense layer and only compress the image slightly, I can easily achieve an extremely low error and near-perfect image reconstruction (at least to my eye). That is fairly obvious, but I suppose it shows I am not making some kind of mistake that renders the output meaningless.

3) My loss never really drops below 0.02-0.03. At around 0.025, though, the images are reconstructed well enough that the output is clearly derived from the input rather than being some sort of degenerate noise (e.g., giving every pixel the same or some fixed intensity). I think getting it below 0.01 would be good enough. My lowest value so far (admittedly on a somewhat easier subset of my data) is 0.018, and when I plot the encoded values in a heatmap I can see clear clustering among the samples (see the sketch after this list for one way to plot this).

4) When my middle dense layer used a ReLU activation, I got a lot of dying ReLUs, which made it much less useful for the eventual clustering, so I switched to tanh (the sketch after this list includes a quick dead-unit check). I also found that 'he_normal' works better as the initializer for the dense layer.

5) Adding more dense layers in the middle did not seem to help at all.

6) Changing the shape of the encoder (so that it goes from fewer to more kernels per layer) did not help either, although I know that is traditionally what a convolutional autoencoder looks like.
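Regarding points 3) and 4), here is a minimal sketch of how the bottleneck encodings can be inspected; encoder is assumed to be a Keras sub-model ending at the middle dense layer, and x_samples a batch of input images (both names are placeholders, not from the original code):

import numpy as np
import matplotlib.pyplot as plt

# Bottleneck activations for a batch of inputs: shape (n_samples, 512)
encodings = encoder.predict(x_samples)

# Point 3: heatmap of the encoded values; clustered samples show up
# as visibly similar rows.
plt.imshow(encodings, aspect='auto', cmap='viridis')
plt.xlabel('bottleneck unit')
plt.ylabel('sample')
plt.colorbar(label='activation')
plt.show()

# Point 4: with a ReLU bottleneck, dead units output exactly 0 for
# every sample; a large fraction here indicates dying ReLUs.
dead = np.mean(np.all(encodings == 0, axis=0))
print(f'fraction of dead bottleneck units: {dead:.2%}')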

Here is the full model (output from model.summary()):

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_4 (InputLayer)         (None, 1024, 1024, 1)     0       
_________________________________________________________________
conv2d_40 (Conv2D)           (None, 1024, 1024, 128)   1280      
_________________________________________________________________
max_pooling2d_19 (MaxPooling (None, 512, 512, 128)     0         
_________________________________________________________________
batch_normalization_4 (Batch (None, 512, 512, 128)     512       
_________________________________________________________________
conv2d_41 (Conv2D)           (None, 512, 512, 64)      73792
_________________________________________________________________
max_pooling2d_20 (MaxPooling (None, 256, 256, 64)      0         
_________________________________________________________________
conv2d_42 (Conv2D)           (None, 256, 256, 32)      18464
_________________________________________________________________
max_pooling2d_21 (MaxPooling (None, 128, 128, 32)      0         
_________________________________________________________________
conv2d_43 (Conv2D)           (None, 128, 128, 16)      4624      
_________________________________________________________________
max_pooling2d_22 (MaxPooling (None, 64, 64, 16)        0         
_________________________________________________________________
conv2d_44 (Conv2D)           (None, 64, 64, 8)         1160      
_________________________________________________________________
max_pooling2d_23 (MaxPooling (None, 32, 32, 8)         0         
_________________________________________________________________
flatten_4 (Flatten)          (None, 8192)              0         
_________________________________________________________________
dense_5 (Dense)              (None, 512)               4194816   
_________________________________________________________________
reshape_4 (Reshape)          (None, 8, 8, 8)           0         
_________________________________________________________________
up_sampling2d_22 (UpSampling (None, 16, 16, 8)         0         
_________________________________________________________________
conv2d_45 (Conv2D)           (None, 16, 16, 16)        1168      
_________________________________________________________________
up_sampling2d_23 (UpSampling (None, 32, 32, 16)        0         
_________________________________________________________________
conv2d_46 (Conv2D)           (None, 32, 32, 16)        2320      
_________________________________________________________________
up_sampling2d_24 (UpSampling (None, 64, 64, 16)        0         
_________________________________________________________________
conv2d_47 (Conv2D)           (None, 64, 64, 32)        4640      
_________________________________________________________________
up_sampling2d_25 (UpSampling (None, 128, 128, 32)      0         
_________________________________________________________________
conv2d_48 (Conv2D)           (None, 128, 128, 64)      18496
_________________________________________________________________
up_sampling2d_26 (UpSampling (None, 256, 256, 64)      0         
_________________________________________________________________
conv2d_49 (Conv2D)           (None, 256, 256, 128)     73856     
_________________________________________________________________
up_sampling2d_27 (UpSampling (None, 512, 512, 128)     0         
_________________________________________________________________
conv2d_50 (Conv2D)           (None, 512, 512, 128)     147584
_________________________________________________________________
up_sampling2d_28 (UpSampling (None, 1024, 1024, 128)   0
_________________________________________________________________
conv2d_51 (Conv2D)           (None, 1024, 1024, 1)     1153
=================================================================
Total params: 4,543,865
Trainable params: 4,543,609
Non-trainable params: 256
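For reference, a sketch that reproduces the summary above; the 3x3 kernels and 'same' padding follow from the parameter counts, while the relu conv activations, sigmoid output, and adam optimizer are assumptions (mse and the tanh/he_normal bottleneck are from points 1 and 4):

from tensorflow.keras import layers, Model

inp = layers.Input(shape=(1024, 1024, 1))

# Encoder: five Conv2D + MaxPooling2D stages, 1024 -> 32 spatially
x = layers.Conv2D(128, 3, padding='same', activation='relu')(inp)
x = layers.MaxPooling2D(2)(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)
x = layers.MaxPooling2D(2)(x)
x = layers.Conv2D(32, 3, padding='same', activation='relu')(x)
x = layers.MaxPooling2D(2)(x)
x = layers.Conv2D(16, 3, padding='same', activation='relu')(x)
x = layers.MaxPooling2D(2)(x)
x = layers.Conv2D(8, 3, padding='same', activation='relu')(x)
x = layers.MaxPooling2D(2)(x)

# Bottleneck: flatten 32x32x8 = 8192, project to 512
x = layers.Flatten()(x)
encoded = layers.Dense(512, activation='tanh', kernel_initializer='he_normal')(x)

# Decoder: reshape 512 -> 8x8x8, then seven UpSampling2D + Conv2D stages
x = layers.Reshape((8, 8, 8))(encoded)
x = layers.UpSampling2D(2)(x)
x = layers.Conv2D(16, 3, padding='same', activation='relu')(x)
x = layers.UpSampling2D(2)(x)
x = layers.Conv2D(16, 3, padding='same', activation='relu')(x)
x = layers.UpSampling2D(2)(x)
x = layers.Conv2D(32, 3, padding='same', activation='relu')(x)
x = layers.UpSampling2D(2)(x)
x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)
x = layers.UpSampling2D(2)(x)
x = layers.Conv2D(128, 3, padding='same', activation='relu')(x)
x = layers.UpSampling2D(2)(x)
x = layers.Conv2D(128, 3, padding='same', activation='relu')(x)
x = layers.UpSampling2D(2)(x)
x = layers.Conv2D(1, 3, padding='same', activation='sigmoid')(x)

autoencoder = Model(inp, x)
autoencoder.compile(optimizer='adam', loss='mse')  # mse per point 1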

1 Answer:

Answer 0 (score: 0)

Your loss function is probably the problem. Use BCE on the logit outputs of the network; that should solve it.

  1. Use: tf.keras.losses.BinaryCrossentropy(from_logits=True)
  2. Remove the activation function from the last layers of the encoder and decoder (the encoder's final dense layer and the decoder's final Conv layer should have no activation).
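In Keras terms, that compile step might look like this (a sketch; autoencoder stands for the model above with its final sigmoid activation removed):

import tensorflow as tf

# The loss applies the sigmoid internally, which is more numerically
# stable than a sigmoid output layer followed by plain BCE.
autoencoder.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
)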

Note: when reconstructing from the embedding, apply a sigmoid to the output:

import tensorflow as tf

z = encoder(x)                          # embedding from the (activation-free) encoder
x_hat_raw = decoder(z)                  # raw logits from the decoder
reconstruction = tf.sigmoid(x_hat_raw)  # squash logits to [0, 1] pixel intensities

It should train properly now!