如何避免fp16
中的tf.linalg.norm
中间溢出而最终结果没有溢出?
我正在尝试可视化神经网络的优化表面。为此,我正在评估关于相同的x
和y
的数千个参数设置的损失。我将fp16
(即半精度)用于较大的批量。问题是,计算损失的最后一步(即tf.linalg.norm
)将(几乎总是)溢出并返回inf
:
>>> tf.linalg.norm(myArray, axes=(1,2))
[inf, inf, inf, ..., inf]
一种解决方案是在计算之前将整个数组cast
fp32
(即单精度):
>>> tf.linalg.norm(tf.cast(myArray, tf.float32), axes=(1,2))
[~10k, ~10k, ~10k, ..., ~10k]
但是我相信此过程会尝试在GPU内存中分配两倍大的数组(并保留旧数组,直到完成cast
为止)。如果GPU内存不足,则该过程将失败。我认为应该有更多的内存有效(和更快)方式,但是我不确定如何。