GradientTape returns None when using a custom loss function

Date: 2020-01-02 20:02:04

Tags: python-3.x keras tensorflow2.0

Given some text input, I am trying to produce an output that has the same semantic encoding as the input. To do this, I trained an autoencoder and kept only the encoder part so I can compare sequence embeddings. Here is the code for training the new decoder:

with tf.GradientTape() as gen_tape:
    enc_output, enc_hidden = enc(input_batch, enc_hidden)
    gen_hidden = enc_hidden
    all_outputs = [[tokenizer.word_index[START_TOKEN]] * BATCH_SIZE]
    gen_input = tf.expand_dims([tokenizer.word_index[START_TOKEN]] * BATCH_SIZE, 1) #First input is list of start tokens
    gen_loss = 0
    for t in range(1, input_batch.shape[1]):
        predictions, gen_hidden, _ = gen(gen_input, gen_hidden, enc_output)
        predictions_am = tf.expand_dims(tf.argmax(predictions, 1), 1) #take most likely prediction for each row
        all_outputs.append(tf.argmax(predictions, 1))
        gen_input = predictions_am #predicted IDs are fed back into the model

    all_outputs = tf.stack(all_outputs, 1) #build list of full length predictions
    #Get the embedding vectors for the original input and the predictions
    e1 = enc(all_outputs, enc.get_def_hidden_state())[0] #re-encode the generated (argmax) token IDs
    e2 = enc_output #encoding of the original input batch
    gen_loss = -tf.keras.losses.cosine_similarity(e1, e2) + 1 #calculate loss based on how similar they are

gen_grads = gen_tape.gradient(gen_loss, gen.trainable_weights)
gen_optimizer.apply_gradients(zip(gen_grads, gen.trainable_weights))

gen_grads always ends up being a list of Nones.

1 Answer:

Answer 0 (score: 0)

Argmax is not differentiable. You cannot use it as the model output that the loss is computed from. You need to keep the raw (soft) predictions all the way to the end.
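
A minimal sketch of the problem on toy tensors (standalone, not the model above): TensorFlow registers no gradient for argmax, so the tape returns None, whereas a soft operation such as softmax keeps the gradient path intact.

import tensorflow as tf

x = tf.Variable([[0.1, 0.9]])

with tf.GradientTape() as tape:
    hard = tf.cast(tf.argmax(x, axis=1), tf.float32) #argmax cuts the gradient path
    loss_hard = tf.reduce_sum(hard)
print(tape.gradient(loss_hard, x)) #None

with tf.GradientTape() as tape:
    soft = tf.nn.softmax(x) #differentiable, keeps the graph connected
    loss_soft = -tf.keras.losses.cosine_similarity(soft, tf.constant([[0.0, 1.0]])) + 1
print(tape.gradient(loss_soft, x)) #an actual gradient tensor

Applied to the loop in the question, this would mean collecting the raw predictions (or their softmax) in all_outputs and computing e1 from those, using predictions_am only to choose the token IDs fed back into the decoder. How exactly to feed soft predictions into enc depends on its embedding layer, so treat this as a direction rather than a drop-in fix.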