Loss does not change when using a skip connection

Time: 2017-08-31 07:55:30

Tags: python-2.7 keras convolution loss deep-residual-networks

I am trying to implement this paper in Keras: https://arxiv.org/pdf/1603.09056.pdf. It uses a Conv-Deconv architecture with skip connections to build an image denoising network. My network trains fine as long as I only add the symmetric skip connections between corresponding Conv and Deconv layers, but as soon as I add the connection between the input and the output (as described in the paper), it stops learning. Am I misunderstanding the paper?

"However, our network learns the additive corruption from the input, since there is a skip connection between the input and the output of the network."

Here is the network described in the paper:

[image: the Conv-Deconv network architecture from the paper]

And here is my network:

from keras.layers import Input, Convolution2D, Conv2DTranspose, Activation, add
from keras.models import Model

input_img = Input(shape=(None,None,3))

############################
####### CONVOLUTIONS #######
############################

c1 = Convolution2D(64, (3, 3))(input_img)
a1 = Activation('relu')(c1)

c2 = Convolution2D(64, (3, 3))(a1)
a2 = Activation('relu')(c2)

c3 = Convolution2D(64, (3, 3))(a2)
a3 = Activation('relu')(c3)

c4 = Convolution2D(64, (3, 3))(a3)
a4 = Activation('relu')(c4)

c5 = Convolution2D(64, (3, 3))(a4)
a5 = Activation('relu')(c5)

############################
###### DECONVOLUTIONS ######
############################

d1 = Conv2DTranspose(64, (3, 3))(a5)
a6 = Activation('relu')(d1)

m1 = add([a4, a6])  # symmetric skip: conv features added to deconv output
a7 = Activation('relu')(m1)

d2 = Conv2DTranspose(64, (3, 3))(a7)
a8 = Activation('relu')(d2)

m2 = add([a3, a8])
a9 = Activation('relu')(m2)

d3 = Conv2DTranspose(64, (3, 3))(a9)
a10 = Activation('relu')(d3)

m3 = add([a2, a10])
a11 = Activation('relu')(m3)

d4 = Conv2DTranspose(64, (3, 3))(a11)
a12 = Activation('relu')(d4) 

m4 = add([a1, a12])
a13 = Activation('relu')(m4)

d5 = Conv2DTranspose(3, (3, 3))(a13)
a14 = Activation('relu')(d5)

m5 = add([input_img, a14]) # Everything goes well without this line
out = Activation('relu')(m5)

model = Model(input_img, out) 
model.compile(optimizer='adam', loss='mse')

Here is what I get when I train it:

Epoch 1/10
31250/31257 [============================>.] - ETA: 0s - loss: 0.0015
Current PSNR: 28.1152534485

31257/31257 [==============================] - 89s - loss: 0.0015 - val_loss: 0.0015
Epoch 2/10
31250/31257 [============================>.] - ETA: 0s - loss: 0.0015
Current PSNR: 28.1152534485

31257/31257 [==============================] - 89s - loss: 0.0015 - val_loss: 0.0015
Epoch 3/10
31250/31257 [============================>.] - ETA: 0s - loss: 0.0015
Current PSNR: 28.1152534485

31257/31257 [==============================] - 89s - loss: 0.0015 - val_loss: 0.0015
Epoch 4/10
31250/31257 [============================>.] - ETA: 0s - loss: 0.0015
Current PSNR: 28.1152534485

31257/31257 [==============================] - 89s - loss: 0.0015 - val_loss: 0.0015
Epoch 5/10
31250/31257 [============================>.] - ETA: 0s - loss: 0.0015
Current PSNR: 28.1152534485
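For what it's worth, the frozen PSNR is consistent with the flat MSE: for images scaled to [0, 1], PSNR = 10 * log10(1 / MSE). A quick check, using the value from the log above:

import math

mse = 0.0015
print(10 * math.log10(1.0 / mse))  # ~28.2 dB, close to the logged 28.115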

What is wrong with my network?

1 Answer:

Answer 0 (score: 0)

An activation of 'relu' can never return negative values.
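Since relu(x) = max(x, 0), any negative correction is simply zeroed out; a one-line illustration in NumPy:

import numpy as np
print(np.maximum(np.array([-0.3, 0.0, 0.5]), 0))  # [0.  0.  0.5] -- negatives are lost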

Since you are adding the input to an output (a14) that is supposed to "denoise" it (remove the noise), that output is of course expected to contain both positive and negative values. (You want to darken bright spots and brighten dark spots.)

So the activation of a14 cannot be 'relu'. It must be something that can reach both the positive and the negative range of the noise, probably 'tanh' or a custom activation. If your input goes from 0 to 1, 'tanh' is probably the best option.
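A minimal sketch of that change applied to the tail of the network in the question; only the activation after d5 swaps from 'relu' to 'tanh' (keeping the final 'relu' after the add is an assumption, on the idea that targets stay in [0, 1]):

d5 = Conv2DTranspose(3, (3, 3))(a13)
a14 = Activation('tanh')(d5)    # the learned residual can now be negative too

m5 = add([input_img, a14])      # input + residual
out = Activation('relu')(m5)    # clipping the sum at 0 (assumption: targets in [0, 1])

model = Model(input_img, out)
model.compile(optimizer='adam', loss='mse')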

(Not sure about the earlier layers; maybe using 'tanh' in some of them would make the job easier.)

Those long convolutional networks do sometimes get stuck, too. I was training a U-net here and it took a while for it to converge. When it gets stuck, it is sometimes better to rebuild the model (fresh weight initialization) and try again.
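A minimal sketch of that restart pattern, using a hypothetical build_model() helper (shown with a shrunken two-layer stand-in for the real stack) so that every call re-initializes the weights:

from keras.layers import Input, Convolution2D, Conv2DTranspose, Activation, add
from keras.models import Model

def build_model():
    # rebuilding the graph from scratch draws fresh random weights
    inp = Input(shape=(None, None, 3))
    x = Activation('relu')(Convolution2D(64, (3, 3))(inp))
    x = Activation('tanh')(Conv2DTranspose(3, (3, 3))(x))
    return Model(inp, add([inp, x]))

model = build_model()   # if training stalls, call again and retrain from scratch
model.compile(optimizer='adam', loss='mse')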

See details here: How to build a multi-class convolutional neural network with Keras