Error when checking input: expected input_2 to have shape (250, 250, 3) but got array with shape (200, 200, 3) when training a GAN

Date: 2020-08-07 01:23:00

Tags: numpy tensorflow keras deep-learning generative-adversarial-network

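The variable X_train has shape (1000, 100, 100, 3) and Y_train has shape (1000, 250, 250, 3).
The error occurs on the line d_loss_real = D.train_on_batch(imgy, valid). I tried changing the input size of D, but when I changed it to (200, 200, 3) I got the same error, only with the expected shape and the actual input shape swapped.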
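A quick way to see where 200 comes from is to trace the spatial size by hand through the layers in the code below. The plain-Python sketch that follows does exactly that. It is not Keras code and it rests on an assumption: that UpSampling2D((2.5, 2.5)) ends up applying an integer factor of 2 (in TF 2.x Keras the factor appears to be cast to int32 inside the resize). It also assumes that Conv2D with padding='same' and stride 1 keeps the size, and that pooling with padding='same' (stride equal to the pool size) yields ceil(n / pool). The helper names are hypothetical.

```python
import math


# Hypothetical helpers mimicking Keras shape rules for the layers used here.
def upsample(n, factor):
    # ASSUMPTION: a non-integer UpSampling2D factor is truncated, so 2.5 acts like 2.
    return n * int(factor)


def pool_same(n, pool):
    # MaxPooling2D / AveragePooling2D with padding='same' and strides == pool_size.
    return math.ceil(n / pool)


def generator_side(n=100):
    # Conv2D(..., padding='same') with stride 1 keeps the spatial size,
    # so only UpSampling2D((2.5, 2.5)) changes it.
    return upsample(n, 2.5)


def discriminator_side(n):
    # MaxPooling2D(2) -> MaxPooling2D(2) -> MaxPooling2D(4) -> AveragePooling2D(32)
    sizes = [n]
    for pool in (2, 2, 4, 32):
        sizes.append(pool_same(sizes[-1], pool))
    return sizes


g_out = generator_side(100)
print("G output side:", g_out)
print("D trace for real images:", discriminator_side(250))
print("D trace for generated images:", discriminator_side(g_out))
```

If that assumption holds, G emits 200×200 images, so the array rejected by D's 250×250 input would be gen_imgs (passed to D in the d_loss_fake line) rather than the 250×250 imgy. Changing D to (200, 200, 3) would then make imgy fail instead, which matches the swapped error described above. The trace also suggests D's output is (batch_size, 1, 1, 1), not the (batch_size, 1) shape of valid and fake.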

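import pickle

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (AveragePooling2D, Conv2D, MaxPooling2D,
                                     UpSampling2D)

epochs = 1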
batch_size = 10
sample_interval = 1

with open('X.data', 'rb') as f:
    X_train = np.array(pickle.load(f))/255
with open('Y.data', 'rb') as f:
    Y_train = np.array(pickle.load(f))/255
# Generator: (100, 100, 3) input image -> upsampled RGB image
inputs = tf.keras.Input(shape=(100, 100, 3))
x = Conv2D(8, (10, 10), activation='relu', padding='same')(inputs)
x = UpSampling2D((2.5, 2.5))(x)
x = Conv2D(8, (10, 10), activation='relu', padding='same')(x)
x = Conv2D(3, (10, 10), activation='relu', padding='same')(x)
G = tf.keras.Model(inputs=inputs, outputs=x)

G.summary()

# Discriminator: (250, 250, 3) image -> real/fake score
inputs = tf.keras.Input(shape=(250, 250, 3))
x = Conv2D(8, (10, 10), activation='relu', padding='same')(inputs)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(8, (10, 10), activation='relu', padding='same')(x)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(8, (10, 10), activation='relu', padding='same')(x)
x = MaxPooling2D((4, 4), padding='same')(x)
x = Conv2D(1, (10, 10), activation='relu', padding='same')(x)
x = AveragePooling2D((32, 32), padding='same')(x)

D = tf.keras.Model(inputs=inputs, outputs=x)
D.compile(loss='binary_crossentropy', optimizer='adam')


# Combined model: G followed by a frozen D, used to train G
z = tf.keras.Input(shape=(100, 100, 3))
img = G(z)
D.trainable = False
validity = D(img)
combined = tf.keras.Model(z, validity)
combined.compile(loss='binary_crossentropy', optimizer='adam')

valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

# Training loop
for epoch in range(epochs):
    idx = np.random.randint(0, X_train.shape[0], batch_size)
    imgx = X_train[idx]
    imgy = Y_train[idx]

    g_loss = combined.train_on_batch(imgx, valid)
    
    gen_imgs = G.predict(imgx)
    
    d_loss_real = D.train_on_batch(imgy, valid)
    d_loss_fake = D.train_on_batch(gen_imgs, fake)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
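If the goal is a 250×250 output from a 100×100 input, one option that uses only integer factors is to upsample by 5 and then downsample by 2 (100 × 5 / 2 = 250). In Keras this would presumably mean replacing UpSampling2D((2.5, 2.5)) with UpSampling2D(size=(5, 5)) followed by AveragePooling2D((2, 2)); Conv2DTranspose or tf.image.resize are other possibilities. The NumPy sketch below only reproduces those two steps on a plain array to check the shape arithmetic. The helper names upsample_nearest and avg_pool are hypothetical and are not Keras APIs.

```python
import numpy as np


def upsample_nearest(x, factor):
    # Nearest-neighbour upsampling on a (batch, H, W, C) array: repeat each
    # pixel `factor` times along H and W (UpSampling2D's default behaviour).
    return np.repeat(np.repeat(x, factor, axis=1), factor, axis=2)


def avg_pool(x, pool):
    # Non-overlapping average pooling; assumes H and W are divisible by `pool`.
    b, h, w, c = x.shape
    return x.reshape(b, h // pool, pool, w // pool, pool, c).mean(axis=(2, 4))


batch = np.random.rand(2, 100, 100, 3).astype(np.float32)
out = avg_pool(upsample_nearest(batch, 5), 2)
print(out.shape)  # (2, 250, 250, 3)
```

Going through 500×500 costs more memory than a direct resize, but every factor stays an integer, so the output size is exactly 250 and no truncation is involved.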

0 answers