如何使Keras-GAN使用非标准输入形状?

时间:2019-03-24 19:01:00

标签: keras

我一直在使用这个Keras-GAN存储库,特别是它的pix2pix软件包来进行图像到图像的翻译。当输入形状为标准(256,256)时,它的效果很好,还可以与(28,28)MNIST数据一起使用。但是,当我尝试输入形状(196,208)时,出现以下ValueError:

  

“连接”层要求输入的形状与连接轴一致,但形状必须匹配。得到了输入形状:[[None,8,8,512),(None,7,7,512)]

对于某些情况,在下面的代码块中调用了Concatenate函数:

# Image input
d0 = Input(shape=self.img_shape)

# Downsampling
d1 = conv2d(d0, self.gf, bn=False)
d2 = conv2d(d1, self.gf*2)
d3 = conv2d(d2, self.gf*4)
d4 = conv2d(d3, self.gf*8)
d5 = conv2d(d4, self.gf*8)
d6 = conv2d(d5, self.gf*8)
d7 = conv2d(d6, self.gf*8)

# Upsampling
u1 = deconv2d(d7, d6, self.gf*8)
u2 = deconv2d(u1, d5, self.gf*8)
u3 = deconv2d(u2, d4, self.gf*8)
u4 = deconv2d(u3, d3, self.gf*4)
u5 = deconv2d(u4, d2, self.gf*2)
u6 = deconv2d(u5, d1, self.gf)

其中conv2d和deconv2d为:

def conv2d(layer_input, filters, f_size=4, bn=True):
 """Layers used during downsampling"""
 d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
    d = LeakyReLU(alpha=0.2)(d)
 if bn:
        d = BatchNormalization(momentum=0.8)(d)
 return d

def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
 """Layers used during upsampling"""
 u = UpSampling2D(size=2)(layer_input)
    u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
 if dropout_rate:
        u = Dropout(dropout_rate)(u)
    u = BatchNormalization(momentum=0.8)(u)
    u = Concatenate()([u, skip_input])
 return u

0 个答案:

没有答案