为什么自定义训练循环平均损失不超过batch_size?

时间:2021-01-04 01:46:18

标签: tensorflow machine-learning deep-learning

下面的代码片段是来自 Tensorflow 官方教程的自定义训练循环。https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch。另一个教程也没有平均损失超过 batch_size,如下所示 https://www.tensorflow.org/tutorials/customization/custom_training_walkthrough

为什么在这一行 loss_value = loss_fn(y_batch_train, logits) 处的 loss_value 没有在 batch_size 上取平均值?这是一个错误吗?从这里的另一个问题 Loss function works with reduce_mean but not reduce_sum,确实需要 reduce_mean 来平均损失超过 batch_size

loss_fn 在教程中定义如下。它显然没有超过 batch_size 的平均值。

loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

从文档中,keras.losses.SparseCategoricalCrossentropy 在没有平均的情况下对批次的损失求和。因此,这本质上是 reduce_sum 而不是 reduce_mean

Type of tf.keras.losses.Reduction to apply to loss. Default value is AUTO. AUTO indicates that the reduction option will be determined by the usage context. For almost all cases this defaults to SUM_OVER_BATCH_SIZE.

代码如下所示。

epochs = 2
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):

        # Open a GradientTape to record the operations run
        # during the forward pass, which enables auto-differentiation.
        with tf.GradientTape() as tape:

            # Run the forward pass of the layer.
            # The operations that the layer applies
            # to its inputs are going to be recorded
            # on the GradientTape.
            logits = model(x_batch_train, training=True)  # Logits for this minibatch

            # Compute the loss value for this minibatch.
            loss_value = loss_fn(y_batch_train, logits)

        # Use the gradient tape to automatically retrieve
        # the gradients of the trainable variables with respect to the loss.
        grads = tape.gradient(loss_value, model.trainable_weights)

        # Run one step of gradient descent by updating
        # the value of the variables to minimize the loss.
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %s samples" % ((step + 1) * 64))

1 个答案:

答案 0 :(得分:0)

我已经弄清楚了,默认情况下 loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True) 确实平均损失超过 batch_size。