DCGAN - 理解代码

时间:2017-12-26 10:12:25

标签: python-3.x tensorflow keras conv-neural-network

这是解卷积 - 对话生成对抗网络(DC-GAN)的一部分代码

discriminator.trainable = False
ganInput = Input(shape=(100,))
# getting the output of the generator
# and then feeding it to the discriminator
# new model = D(G(input))
x = generator(ganInput)
ganOutput = discriminator(x)
gan = Model(input=ganInput, output=ganOutput)
gan.compile(loss='binary_crossentropy', optimizer=Adam())

问题1 - 我不明白ganInput = Input(shape =(100,))行的作用。显然ganInput是一个变量,但什么是输入?这是一个功能吗?如果Input是一个函数,那么ganInput包含什么?

问题2 - Model API的作用是什么?我在keras文档中读过但未能理解它在这里做了什么。

请询问您需要的任何进一步澄清/详细信息。

Keras与TensorFlow后端 完整的源代码: https://github.com/yashk2810/DCGAN-Keras/blob/master/DCGAN.ipynb

1 个答案:

答案 0 :(得分:2)

ganInput = Input(shape=(100,))只是定义输入的形状 这是一个形状的张量(100,)

模型将包括计算给定输入的输出所需的所有层。对于多输入或多输出模型,您也可以使用列表:

model = Model(inputs=[ganInput1, ganInput2], outputs=[ganOutput1, ganOutput2, ganOutput3])

这意味着计算模型api需要的ganOutput1,ganOutput2,ganOutput3 输入图层ganInput1,ganInput2

这对于回溯是必要的,因此Model api具有计算输出所需的内容

此行加载mnist数据:(X_train, Y_train), (X_test, Y_test) = mnist.load_data() .... X_trainY_train包含训练数据及其相应的目标值.... X_test,{{ 1}}具有训练数据及其相应的目标值

Y_test