TensorFlow custom training step with a different loss function

Posted: 2020-11-09 21:39:11

Tags: python tensorflow machine-learning keras deep-learning

Background

According to the TensorFlow documentation, a custom training step can be performed with the following:

import tensorflow as tf
from tensorflow import keras

# Fake sample data for testing (`model` and `optimizer` are assumed
# to be defined as in the TensorFlow documentation example)
x_batch_train = tf.zeros([32, 3, 1], dtype="float32")
y_batch_train = tf.zeros([32], dtype="float32")
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
with tf.GradientTape() as tape:
    logits = model(x_batch_train, training=True)
    loss_value = loss_fn(y_batch_train, logits)

grads = tape.gradient(loss_value, model.trainable_weights)
optimizer.apply_gradients(zip(grads, model.trainable_weights))

However, if I want to use a different loss function, such as categorical cross-entropy, it seems I would need to take the argmax of the logits created inside the gradient tape:

loss_fn = tf.keras.losses.get("categorical_crossentropy")
with tf.GradientTape() as tape:
    logits = model(x_batch_train, training=True)
    prediction = tf.cast(tf.argmax(logits, axis=-1), y_batch_train.dtype)
    loss_value = loss_fn(y_batch_train, prediction)

grads = tape.gradient(loss_value, model.trainable_weights)
optimizer.apply_gradients(zip(grads, model.trainable_weights))

The problem

The problem is that the tf.argmax function is not differentiable, so TensorFlow cannot compute a gradient through it, and you get the error:

ValueError: No gradients provided for any variable: [...]
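The non-differentiability can be seen in isolation: a minimal sketch (not from the question, just an illustration) where the only path from the watched variable to the loss goes through tf.argmax, so the tape returns no gradient.

```python
import tensorflow as tf

x = tf.Variable([0.1, 0.9])
with tf.GradientTape() as tape:
    # argmax returns an integer index; no gradient flows through it
    y = tf.cast(tf.argmax(x), tf.float32)

grad = tape.gradient(y, x)
print(grad)  # None: TensorFlow found no differentiable path from x to y
```

This `None` is exactly what triggers the `ValueError` once it is passed to `optimizer.apply_gradients`.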

My question: how can I make the second example work without changing the loss function?

1 Answer:

Answer 0: (score: 2)

categorical_crossentropy expects your labels to be one-hot encoded, so you should make sure of that first. Then pass the model's output to the loss directly, without any argmax; that output should contain one probability (or logit) per class. More info -> https://www.tensorflow.org/api_docs/python/tf/keras/losses/CategoricalCrossentropy#standalone_usage
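Putting the answer together, here is a minimal sketch of the corrected training step. The model architecture and optimizer are hypothetical stand-ins (the question does not define them); the key points are one-hot encoding the labels with tf.one_hot and feeding the raw logits to the loss instead of an argmax.

```python
import tensorflow as tf
from tensorflow import keras

# Hypothetical model/optimizer matching the question's input shape [32, 3, 1]
num_classes = 4
model = keras.Sequential([
    keras.layers.Flatten(),
    keras.layers.Dense(num_classes),
])
optimizer = keras.optimizers.Adam()

x_batch_train = tf.zeros([32, 3, 1], dtype="float32")
y_batch_train = tf.zeros([32], dtype="int32")  # integer class ids

# One-hot encode the sparse labels, as categorical_crossentropy expects
y_one_hot = tf.one_hot(y_batch_train, depth=num_classes)

loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
with tf.GradientTape() as tape:
    logits = model(x_batch_train, training=True)  # raw logits, no argmax
    loss_value = loss_fn(y_one_hot, logits)

# Gradients now exist because every op from weights to loss is differentiable
grads = tape.gradient(loss_value, model.trainable_weights)
optimizer.apply_gradients(zip(grads, model.trainable_weights))
```

With `from_logits=True`, the softmax is folded into the loss, so no argmax or explicit probability conversion is needed inside the tape.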