这是解卷积 - 对话生成对抗网络(DC-GAN)的一部分代码
discriminator.trainable = False
ganInput = Input(shape=(100,))
# getting the output of the generator
# and then feeding it to the discriminator
# new model = D(G(input))
x = generator(ganInput)
ganOutput = discriminator(x)
gan = Model(input=ganInput, output=ganOutput)
gan.compile(loss='binary_crossentropy', optimizer=Adam())
问题1 - 我不明白ganInput = Input(shape =(100,))行的作用。显然ganInput是一个变量,但什么是输入?这是一个功能吗?如果Input是一个函数,那么ganInput包含什么?
问题2 - Model
API的作用是什么?我在keras文档中读过但未能理解它在这里做了什么。
请询问您需要的任何进一步澄清/详细信息。
Keras与TensorFlow后端 完整的源代码: https://github.com/yashk2810/DCGAN-Keras/blob/master/DCGAN.ipynb
答案 0 :(得分:2)
行ganInput = Input(shape=(100,))
只是定义输入的形状
这是一个形状的张量(100,)
模型将包括计算给定输入的输出所需的所有层。对于多输入或多输出模型,您也可以使用列表:
model = Model(inputs=[ganInput1, ganInput2], outputs=[ganOutput1, ganOutput2, ganOutput3])
这意味着计算模型api需要的ganOutput1,ganOutput2,ganOutput3 输入图层ganInput1,ganInput2
这对于回溯是必要的,因此Model api具有计算输出所需的内容
此行加载mnist数据:(X_train, Y_train), (X_test, Y_test) = mnist.load_data()
.... X_train
和Y_train
包含训练数据及其相应的目标值.... X_test
,{{ 1}}具有训练数据及其相应的目标值
Y_test