在自定义TextGAN中采样后无法计算RNN生成器的梯度

时间:2018-12-01 09:25:00

标签: python tensorflow

在Text GAN实验中计算梯度时遇到一些问题。设置如下(使用TensorFlow Eager执行):

  1. 文本被传递到RNN编码器。
  2. RNN解码器:
    • 我用RNN编码器隐藏状态为解码器加注
    • 我用一个开始令牌启动解码器,并采样max_seq_length,在每个时间步长将输出作为输入反馈
  3. 我将解码后的字符串传递给鉴别器,并在其中执行常规操作。

现在-问题是,当我尝试计算相对于生成器部分(编码器+解码器样本)的鉴别器损耗的梯度时,GradientTape仅返回一个None值的列表。但是,如果我尝试计算与鉴别器有关的损耗梯度,则它会起作用。我也在预训练能正常工作的生成器(编码器/解码器)。

供参考;编码器/解码器几乎是this官方TensorFlow示例的复制/粘贴。下面的代码在TensorFlow示例之后 运行,因为我使用该示例对编码器/解码器进行了预训练。

为了使它正常工作,我一直在努力工作,以至于代码可能有点“丑陋”,但这是无法正常工作的部分:

for epoch in range(EPOCHS):  
    start = time.time()  
    hidden = encoder.initialize_hidden_state()  
    total_generator_loss = 0  
    total_discriminator_loss = 0  
    for (batch, (inp, orig, targ)) in enumerate(dataset):
        with tf.GradientTape() as tape:  
            enc_output, enc_hidden = encoder(inp, hidden)  
            dec_hidden = enc_hidden  
            results = tf.convert_to_tensor(np.array(
                [[original_sentence.word2idx['<start>']] 
                for _ in range(BATCH_SIZE)], dtype=np.int64))

            #
            # I've also tried wrapping the below loop inside a tf.while_loop,
            # though I may have done it incorrectly...
            #
            for _ in range(1, max_length_orig):
                dec_input = tf.expand_dims(results[:, -1], 1)
                predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)
                results = tf.concat([results, tf.multinomial(predictions, num_samples=1)], 1)

            fake_logits = rnn_discriminator(results)
            generator_loss = losses.generator_loss_function(fake_logits)  

        generator_variables = encoder.variables + decoder.variables

        #
        # The line below is the one that's producing [None, ..., None]
        #
        generator_gradients = tape.gradient(generator_loss, generator_variables)  
        generator_optimizer.apply_gradients(zip(generator_gradients, generator_variables))  

        # 
        # The part below here is working
        #
        with tf.GradientTape() as tape:  
            target_logits = rnn_discriminator(targ)
            discriminator_loss = losses.discriminator_loss_function(fake_logits, target_logits)  

        discriminator_gradients = tape.gradient(discriminator_loss, rnn_discriminator.variables)  
        discriminator_optimizer.apply_gradients(
            zip(discriminator_gradients, rnn_discriminator.variables)
        )

        total_generator_loss += generator_loss  
        total_discriminator_loss += discriminator_loss

修改: 我意识到tf.multinomial运算可能不可微,这就是为什么渐变不会流过该点的原因。 但是,我还没有弄清楚如何通过此操作-想法非常感谢!

0 个答案:

没有答案