How to fix "You must feed a value for placeholder tensor" when working with a GAN framework?

Posted: 2019-05-31 07:51:21

Tags: python tensorflow keras

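I'm trying to implement a GAN with Keras (TensorFlow backend) to colorize images. My generator takes a grayscale image as input, while my discriminator takes both the grayscale image and the color image as input.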

How can I train the generator without getting the following error?

InvalidArgumentError: You must feed a value for placeholder tensor 'input_2' with dtype float and shape [?,128,128,1] [[{{node input_2}}]] [[{{node metrics/acc/Mean_2}}]]

I have tried different ways of building the combined model (used to train the generator), without success. I'm using Python 3.6.7 and Keras 2.2.4.

import numpy as np

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, AveragePooling2D, Dense, Dropout, Flatten, Lambda, MaxPool2D, Conv2DTranspose, UpSampling2D, Concatenate, Add
from tensorflow.keras import optimizers
from tensorflow.keras.preprocessing import image  # use tf.keras throughout instead of mixing in standalone keras

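def combine_generator(gen1, gen2):
    # Both iterators are created with the same seed, so their shuffled file order
    # stays in sync and each color batch matches its grayscale batch.
    while True:
        yield next(gen1), next(gen2)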

def generator_model(input_img):   
    outputs = Conv2D(3, (1, 1), activation='sigmoid') (input_img)
    model = Model(inputs=[input_img], outputs=[outputs])
    return model

def discriminator_model(output_img, input_img):
    a1 = Concatenate()([output_img, input_img])
    f1 = Flatten()(a1)
    output = Dense(1, activation='sigmoid')(f1)
    model = Model(inputs=[output_img, input_img], outputs=[output])
    return model

def generator_containing_discriminator(input_img, generator, discriminator):
    goutput = generator(input_img)
    discriminator.trainable=False
    doutput = discriminator([goutput, input_img])
    model = Model(inputs=[input_img], outputs=[doutput])
    return model

seed = 123456
input_size = 128
batch_size = 8
learning_rate = 1e-3
optimizer = optimizers.Adam(lr=learning_rate)

dir_train_img = "flowers_train"
data_gen = dict(rescale=1./255)
image_datagen = image.ImageDataGenerator(**data_gen)

color_generator_train = image_datagen.flow_from_directory(dir_train_img, batch_size=batch_size, class_mode=None, target_size=(input_size, input_size), color_mode="rgb", seed=seed)
gray_generator_train = image_datagen.flow_from_directory(dir_train_img, batch_size=batch_size, class_mode=None, target_size=(input_size, input_size), color_mode="grayscale", seed=seed)
train_generator = combine_generator(color_generator_train, gray_generator_train)

dmodel = discriminator_model(Input((input_size, input_size, 3)), Input((input_size, input_size, 1)))
dmodel.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])
dmodel.summary()

gmodel = generator_model(Input((input_size, input_size, 1)))
gmodel.summary()

gdmodel = generator_containing_discriminator(Input((input_size, input_size, 1)), gmodel, dmodel)
gdmodel.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])
gdmodel.summary()

dmodel.trainable=False

train_batch = next(train_generator)
labels2 = np.array([1]*len(train_batch[1]))
gdmodel.train_on_batch(train_batch[1], labels2)

As the summary shows, input_2 corresponds to the discriminator's grayscale input image, but I can't work out where the problem is.
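For reference, here is what each model expects to be fed in this setup: the discriminator takes [color, grayscale] pairs, while the combined model takes only the grayscale image. The NumPy sketch below assembles both kinds of batches; the helper names and the labelling convention (1 = real pair, 0 = generated pair) are assumptions for illustration and are not part of the original code.

```python
import numpy as np

def discriminator_batch(real_color, fake_color, gray):
    """Hypothetical helper: stack real and generated pairs for one discriminator update.

    real_color, fake_color: (N, H, W, 3); gray: (N, H, W, 1).
    Returns inputs in the order [color, gray] plus 0/1 labels.
    """
    color = np.concatenate([real_color, fake_color], axis=0)
    # The same grayscale images condition both the real half and the generated half.
    gray_pair = np.concatenate([gray, gray], axis=0)
    labels = np.concatenate([np.ones(len(real_color)), np.zeros(len(fake_color))])
    return [color, gray_pair], labels

def generator_batch(gray):
    """Hypothetical helper: inputs/labels for the combined (generator + frozen discriminator) model.

    The combined model has a single grayscale input; the generator is rewarded
    when the discriminator classifies its output as real (label 1).
    """
    return gray, np.ones(len(gray))

# Random stand-in data: a batch of 8 images at 128x128, as in the question.
rng = np.random.default_rng(0)
gray = rng.random((8, 128, 128, 1), dtype=np.float32)
real = rng.random((8, 128, 128, 3), dtype=np.float32)
fake = rng.random((8, 128, 128, 3), dtype=np.float32)  # in practice, e.g. gmodel.predict(gray)

(d_color, d_gray), d_labels = discriminator_batch(real, fake, gray)
print(d_color.shape, d_gray.shape, d_labels.shape)
```

In Keras these arrays would then be passed as dmodel.train_on_batch([d_color, d_gray], d_labels) and gdmodel.train_on_batch(*generator_batch(gray)).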

1 Answer:

Answer 0 (score: 0)

The problem went away when I removed the argument metrics=["accuracy"] from the following line:

gdmodel.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])

I don't know why this fixes it, though.
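If you still want to monitor accuracy after dropping metrics=["accuracy"], one workaround is to compute it yourself from the model's predictions. A minimal NumPy sketch, assuming sigmoid outputs thresholded at 0.5; batch_accuracy is a hypothetical helper, not a Keras function:

```python
import numpy as np

def batch_accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of predictions whose thresholded value matches the 0/1 label."""
    y_true = np.asarray(y_true).reshape(-1)
    # A Dense(1) head returns shape (N, 1); flatten it to (N,) to compare with labels.
    y_pred = np.asarray(y_pred).reshape(-1)
    return float(np.mean((y_pred > threshold).astype(y_true.dtype) == y_true))

labels = np.ones(4)
preds = np.array([[0.9], [0.2], [0.7], [0.4]])
print(batch_accuracy(labels, preds))  # 0.5
```

In the question's code this could be used as batch_accuracy(labels2, gdmodel.predict_on_batch(train_batch[1])). Note that predicting after train_on_batch reflects the updated weights, so the number can differ slightly from the accuracy Keras would have reported during that step.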