使用Keras和WGAN进行InvalidArgumentError

时间:2018-03-13 11:36:03

标签: python tensorflow keras

我试图建立一个评论家网络作为Wasserstein距离生成对抗网络的一部分,使用最新的Keras软件包和Tensorflow后端和python 3.6。这是网络架构的代码:

def build_critic():
    critic = Sequential(name='critic')
    critic.add(Conv2D(64, (3, 3), padding='same', 
               kernel_initializer='he_normal', input_shape= 
               (48,48,1),data_format='channels_last'))
    critic.add(LeakyReLU())
    critic.add(Conv2D(128, (3, 3), padding='same', 
               kernel_initializer='he_normal'))
    critic.add(LeakyReLU())
    critic.add(Conv2D(128, (3, 3), padding='same', 
               kernel_initializer='he_normal'))
    critic.add(LeakyReLU())
    critic.add(Conv2D(256, (3, 3), padding='same', 
               kernel_initializer='he_normal'))
    critic.add(LeakyReLU())
    critic.add(GlobalMaxPooling2D(data_format='channels_last'))
    critic.add(Dense(100))
    critic.add(LeakyReLU())
    critic.add(Dense(1))
    return critic

尝试编译网络时,出现错误:

  

InvalidArgumentError(参见上面的回溯):矩阵大小不兼容:在[0]中:[256,48],在[1]中:[256,100]        [[节点:critic_2 / dense_2 / MatMul = MatMul [T = DT_FLOAT,transpose_a = false,transpose_b = false,_device =" / job:localhost / replica:0 / task:0 / device:CPU:0&#34 ;](critic_2 / global_max_pooling2d_1 / Max,dense_2 / kernel / read)]]

我得到了错误的抱怨。为了将两个大小为AxB和CxD的矩阵相乘,B必须等于C,而这里B = 48且C = 256。我怎样才能设置C?或者我还应该考虑其他问题吗?

以下是发电机组和评论家网络的网络摘要:

图层(类型)输出形状参数#

dense_1(密集)(无,1024)1049600

reshape_1(重塑)(无,1,1,1024)0

up_sampling2d_1(UpSampling2(无,3,3,405)0

batch_normalization_1(批量(无,3,3,405)4096

up_sampling2d_2(UpSampling2(无,12,12,1024)0

batch_normalization_2(批量(无,12,12,1024)4096

up_sampling2d_3(UpSampling2(无,48,48,1024)0

conv2d_1(Conv2D)(无,48,48,128)1179776

batch_normalization_3(批量(无,48,48,128)512

conv2d_2(Conv2D)(无,48,48,128)147584

batch_normalization_4(批量(无,48,48,128)512

conv2d_3(Conv2D)(无,48,48,256)295168

batch_normalization_5(批量(无,48,48,256)1024

conv2d_4(Conv2D)(无,48,48,1)2305

总参数:2,684,673 可训练的参数:2,679,553 不可训练的参数:5,120

图层(类型)输出形状参数#

conv2d_5(Conv2D)(无,48,48,64)640

leaky_re_lu_1(LeakyReLU)(无,48,48,64)0

conv2d_6(Conv2D)(无,48,48,128)73856

leaky_re_lu_2(LeakyReLU)(无,48,48,128)0

conv2d_7(Conv2D)(无,48,48,128)147584

leaky_re_lu_3(LeakyReLU)(无,48,48,128)0

conv2d_8(Conv2D)(无,48,48,256)295168

leaky_re_lu_4(LeakyReLU)(无,48,48,256)0

global_max_pooling2d_1(全球(无,256)0

dense_2(密集)(无,100)25700

leaky_re_lu_5(LeakyReLU)(无,100)0

dense_3(密集)(无,1)101

总参数:543,049 可训练的参数:543,049 不可训练的参数:0

0 个答案:

没有答案