混合精度训练导致 NaN-loss

时间:2021-04-19 09:13:04

标签: tensorflow keras

我一直在关注Mixed Precision Guide。因此,我正在设置:

keras.mixed_precision.set_global_policy(mixed_precision)

并像这样包装优化器:

if mixed_precision.startswith('mixed'):
    logger.info(f'Using LossScaleOptimizer for mixed-precision policy "{mixed_precision}"')
    optimizer = keras.mixed_precision.LossScaleOptimizer(optimizer)

我的模型有简单的 Dense 层作为输出,我将其设置为“float32”

# Set dtype explicitly in last layer for mixed-precision training (float32 for numeric stability).
self.output_dense = layers.Dense(vocab_size, dtype=tf.float32)

和我修改为这个的自定义 train_step() 实现:

with tf.GradientTape() as tape:
    model_loss = self.loss_fn(
        inputs,
        y_true=y_true,
        mask=mask
    )

    is_mixed_precision = isinstance(self.optimizer, mixed_precision.LossScaleOptimizer)

    # We always want to return the unmodified model_loss for Tensorboard
    if is_mixed_precision:
        loss = self.optimizer.get_scaled_loss(model_loss)
    else:
        loss = model_loss

    gradients = tape.gradient(loss, self.trainable_variables)

    if is_mixed_precision:
        gradients = self.optimizer.get_unscaled_gradients(gradients)

    return model_loss, gradients

然而,一段时间后我的损失仍然变成NaN

enter image description here

外面我在确认保单是否被模型认可:

logger.info(f'Mixed-precision policy: {mixed_precision}')
logger.info(f'Compute dtype:          {model.compute_dtype}')
logger.info(f'Variable dtype:         {model.variable_dtype}')
keras.py:216] Mixed-precision policy: mixed_float16
keras.py:217] Compute dtype:          float16
keras.py:218] Variable dtype:         float32

但我可以说这是由于 NaN 损失..

有什么明显的我做错或遗漏了吗?知道我如何在这里追踪问题吗?

1 个答案:

答案 0 :(得分:0)

经过一些反思,我想我能够找到问题所在。它位于我自定义的多头注意力层中。更具体地说,问题似乎是我使用 value.dtype.min 将掩码应用于 logits 的掩码,例如:

logits += value.dtype.min * (1.0 - mask)

有趣的是,这甚至一开始就奏效了。考虑一下,您从一开始就很有可能会出现下溢,但我能够训练模型一段时间,直到 NaN 发生。

无论如何,我的解决方案是给它一些空间,所以我只需将 dtype 的最小值除以二:

logits += (value.dtype.min / 2.0) * (1.0 - mask)