VAE建模时,keras中的keras InvalidArgumentError

时间:2018-11-02 00:24:46

标签: python keras

我们正在使用keras api实现vae建模。 1epoch结束时,显示以下错误消息。

  

InvalidArgumentError:不兼容的形状:[1920]与[3808] [/节点:   custom_variational_layer_10 / logistic_loss / mul = Mul [T = DT_FLOAT,   _class = [“ loc:@train ... ad / Reshape”],_device =:CPU:0“](custom_variational_layer_10 /日志,custom_variational_layer_10 /   重塑)]]

我认为,它似乎从第一个Conv2D更改为难以与原始类型匹配的形状。

例如,当输入(755,1,199,1)的输入数据时,将其转换为119-> 60 解码时从60切换到199似乎有误。

(755,1,24,1),数据转换为24-> 12,然后正常处理为12-> 24。

出什么问题了?

input_img = keras.Input(shape=img_shape)

x = layers.Conv2D(filter_size, 3,
                  padding='same', activation='relu',)(input_img)
x = layers.Conv2D(filter_size*2, 3,
                  padding='same', activation='relu',
                  strides=(1, 2))(x)
x = layers.Conv2D(filter_size*2, 3,
                  padding='same', activation='relu')(x)
x = layers.Conv2D(filter_size*2, 3,
                  padding='same', activation='relu')(x)
shape_before_flattening = K.int_shape(x)

x = layers.Flatten()(x)
x = layers.Dense(filter_size, activation='relu')(x)

z_mean = layers.Dense(latent_dim)(x)
z_log_var = layers.Dense(latent_dim)(x)



def sampling(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim),
                              mean=0., stddev=1.)
    return z_mean + K.exp(z_log_var) * epsilon

z = layers.Lambda(sampling)([z_mean, z_log_var])

将z注入Input

decoder_input = layers.Input(K.int_shape(z)[1:])

上采样到

x = layers.Dense(np.prod(shape_before_flattening[1:]),
                 activation='relu')(decoder_input)

将z的大小更改为与属性图大小相同的属性图 在编码器模型的最后一个Flatten层之前

x = layers.Reshape(shape_before_flattening[1:])(x)

使用Conv2DTranspose层和Conv2D层将z解码为属性 与原始输入图像具有相同尺寸的地图

x = layers.Conv2DTranspose(filter_size, 3,
                           padding='same', activation='relu',
                           strides=(1, 2))(x)
x = layers.Conv2D(1, 3,
                  padding='same', activation='sigmoid')(x)   

属性映射的大小与源输入的大小相同

创建解码器模型对象     解码器=型号(decoder_input,x)

z_decoded = decoder(z)

0 个答案:

没有答案