My TensorFlow implementation of the improved WGAN (WGAN-GP) diverges very quickly

Asked: 2017-07-31 07:58:11

Tags: tensorflow

Here is my code:

DEPTH = 64
OUTPUT_SIZE = 28
batch_size = 16
LAMBDA = 10  # gradient-penalty coefficient; 10 is the value recommended in the WGAN-GP paper
def Discriminator(name,inputs):

    with tf.variable_scope(name):
        output = tf.reshape(inputs, [-1, 28, 28, 1])
        output1 = conv2d('d_conv_1', output, ksize=5, out_dim=DEPTH)
        output2 = lrelu('d_lrelu_1', output1)
        output3 = conv2d('d_conv_2', output2, ksize=5, out_dim=2*DEPTH)
        output4 = lrelu('d_lrelu_2', output3)
        output5 = conv2d('d_conv_3', output4, ksize=5, out_dim=4*DEPTH)
        output6 = lrelu('d_lrelu_3', output5)
        # output7 = conv2d('d_conv_4', output6, ksize=5, out_dim=8*DEPTH)
        # output8 = lrelu('d_lrelu_4', output7)
        shape = output6.get_shape().as_list()
        output9 = tf.reshape(output6, [batch_size, shape[1]*shape[2]*shape[3]])
        output0 = fully_connected('d_fc', output9, 1)  # single critic score per sample
        return output0

The generator code is:

def generator(name):
    with tf.variable_scope(name):
        noise = tf.random_normal([batch_size, 100])#.astype('float32')
        # noise = tf.constant(np.random.normal(size=(128, 128)).astype('float32'))

        noise = tf.reshape(noise, [batch_size, 100], 'noise')
        output = fully_connected('g_fc_1', noise, 2*2*8*DEPTH)
        output = tf.reshape(output, [batch_size, 2, 2, 8*DEPTH], 'g_conv')

        output = deconv2d('g_deconv_1', output, ksize=5, outshape=[batch_size, 4, 4, 4*DEPTH])
        output = tf.nn.relu(output)
        output = tf.reshape(output, [batch_size, 4, 4, 4*DEPTH])

        output = deconv2d('g_deconv_2', output, ksize=5, outshape=[batch_size, 7, 7, 2*DEPTH])
        output = tf.nn.relu(output)

        output = deconv2d('g_deconv_3', output, ksize=5, outshape=[batch_size, 14, 14, DEPTH])
        output = tf.nn.relu(output)

        output = deconv2d('g_deconv_4', output, ksize=5, outshape=[batch_size, OUTPUT_SIZE, OUTPUT_SIZE, 1])
        # output = tf.nn.relu(output)
        output = tf.nn.sigmoid(output)
        return tf.reshape(output,[-1,784])

The training code is as follows:

real_data = tf.placeholder(tf.float32, shape=[batch_size, 784])

with tf.variable_scope(tf.get_variable_scope()):
    fake_data = generator('gen')
    disc_real = Discriminator('dis_r', real_data)
    disc_fake = Discriminator('dis_f', fake_data)

t_vars = tf.trainable_variables()
d_vars = [var for var in t_vars if 'd_' in var.name]
g_vars = [var for var in t_vars if 'g_' in var.name]

# compute the losses
gen_cost = tf.reduce_mean(disc_fake)
disc_cost = -tf.reduce_mean(disc_fake) + tf.reduce_mean(disc_real)

# gradient penalty on straight lines between real and fake samples
alpha = tf.random_uniform(shape=[batch_size, 1], minval=0., maxval=1.)
differences = fake_data - real_data
interpolates = real_data + (alpha * differences)
gradients = tf.gradients(Discriminator('dis', interpolates), [interpolates])[0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=1))
gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
disc_cost += LAMBDA * gradient_penalty

gen_train_op = tf.train.AdamOptimizer(
    learning_rate=1e-4, beta1=0.5, beta2=0.9).minimize(gen_cost, var_list=g_vars)
disc_train_op = tf.train.AdamOptimizer(
    learning_rate=1e-4, beta1=0.5, beta2=0.9).minimize(disc_cost, var_list=d_vars)
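
(The ops above only define the graph. For completeness, a minimal training loop following the WGAN-GP paper's schedule of 5 critic updates per generator update might look like the sketch below; next_batch is a hypothetical helper that returns a [batch_size, 784] numpy array of real images.)

n_critic = 5  # critic updates per generator update, as recommended in the WGAN-GP paper

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(100000):
        # several critic (discriminator) steps per generator step
        for _ in range(n_critic):
            batch = next_batch(batch_size)  # hypothetical data helper
            sess.run(disc_train_op, feed_dict={real_data: batch})
        # one generator step
        sess.run(gen_train_op)
        if step % 100 == 0:
            print(step, sess.run(disc_cost, feed_dict={real_data: batch}))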

The error log is:

error_log

Obviously the code does not work: it diverges very quickly. This problem has bothered me for a long time, and I would really like to know its root cause.

1 Answer:

Answer 0 (score: 0):

Are you sure you are enforcing the Lipschitz constraint as in the WGAN paper?

In their paper this is done by hard-clipping the weights of the discriminator.
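
For illustration, a minimal sketch of that weight clipping in TF1, assuming d_vars is the critic's variable list as in the question (0.01 is the threshold used in the paper):

CLIP = 0.01  # clipping threshold from the original WGAN paper

# after every discriminator update, clamp each critic weight to [-CLIP, CLIP]
clip_ops = [tf.assign(v, tf.clip_by_value(v, -CLIP, CLIP)) for v in d_vars]
clip_disc_weights = tf.group(*clip_ops)

# usage: run it right after each critic step
#   sess.run(disc_train_op, feed_dict={real_data: batch})
#   sess.run(clip_disc_weights)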

Original WGAN paper