这个问题是关于在图像上使用TensorFlow的GAN的实现示例。
我已经摘录了我认为足以提供上下文的代码部分reference for full code。在下面的代码中,它定义了一个鉴别器函数,从松散的角度来看,它可以被认为是一个典型的卷积神经网络。可以看出,鉴别器操作取决于y_dim
,有人可以帮助解释什么是y_dim
吗?看Args
,我仍然对y_dim
的定义感到困惑。
class DCGAN(object):
def __init__(self, sess, image_size=108, is_crop=True,
batch_size=64, sample_size=64, output_size=64,
y_dim=None, z_dim=100, gf_dim=64, df_dim=64,
gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default',
checkpoint_dir=None, sample_dir=None):
"""
Args:
sess: TensorFlow session
batch_size: The size of batch. Should be specified before training.
output_size: (optional) The resolution in pixels of the images. [64]
y_dim: (optional) Dimension of dim for y. [None]
z_dim: (optional) Dimension of dim for Z. [100]
gf_dim: (optional) Dimension of gen filters in first conv layer. [64]
df_dim: (optional) Dimension of discrim filters in first conv layer. [64]
gfc_dim: (optional) Dimension of gen units for for fully connected layer. [1024]
dfc_dim: (optional) Dimension of discrim units for fully connected layer. [1024]
c_dim: (optional) Dimension of image color. For grayscale input, set to 1. [3]
"""
...
def discriminator(self, image, y=None, reuse=False):
if reuse:
tf.get_variable_scope().reuse_variables()
if not self.y_dim:
h0 = lrelu(conv2d(image, self.df_dim, name='d_h0_conv'))
h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim * 2, name='d_h1_conv')))
h2 = lrelu(self.d_bn2(conv2d(h1, self.df_dim * 4, name='d_h2_conv')))
h3 = lrelu(self.d_bn3(conv2d(h2, self.df_dim * 8, name='d_h3_conv')))
h4 = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'd_h3_lin')
return tf.nn.sigmoid(h4), h4
else:
yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
x = conv_cond_concat(image, yb)
h0 = lrelu(conv2d(x, self.c_dim + self.y_dim, name='d_h0_conv'))
h0 = conv_cond_concat(h0, yb)
h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim + self.y_dim, name='d_h1_conv')))
h1 = tf.reshape(h1, [self.batch_size, -1])
h1 = tf.concat(1, [h1, y])
h2 = lrelu(self.d_bn2(linear(h1, self.dfc_dim, 'd_h2_lin')))
h2 = tf.concat(1, [h2, y])
h3 = linear(h2, 1, 'd_h3_lin')
return tf.nn.sigmoid(h3), h3
答案 0 :(得分:0)
y_dim是标签的长度。
似乎在DCGAN课程中使用说“我们知道培训标签的数量'”。如果我们这样做,我们可以用已知的尺寸定义所有这些张量。
E.g。查看使用MNIST数据集的main.py文件。看到0-9标签的y_dim设置为10?这是“没有”'第二种选择。
希望这有帮助。