Pytorch操作检测NaNs

时间:2018-01-08 21:03:25

标签: pytorch

是否有Pytorch内部程序来检测张量中的NaN? Tensorflow有tf.is_nantf.check_numerics操作...... Pytorch在某处有类似的东西吗?我在文档中找不到这样的东西......

我正在寻找一个Pytorch内部例程,因为我希望这发生在GPU和CPU上。这不包括基于numpy的解决方案(如np.isnan(sometensor.numpy()).any())......

5 个答案:

答案 0 :(得分:17)

您始终可以利用nan != nan

这一事实
>>> x = torch.tensor([1, 2, np.nan])
tensor([  1.,   2., nan.])
>>> x != x
tensor([ 0,  0,  1], dtype=torch.uint8)

使用pytorch 0.4还有torch.isnan

>>> torch.isnan(x)
tensor([ 0,  0,  1], dtype=torch.uint8)

答案 1 :(得分:12)

从PyTorch 0.4.1开始,有detect_anomaly上下文管理器,它自动在后向传播的所有步骤之间插入与assert not torch.isnan(grad).any()等效的声明。当在反向传递过程中出现问题时,这非常有用。

答案 2 :(得分:7)

正如@cleros在对@nemo答案的评论中所建议的那样,您可以使用any()运算符将其作为布尔值获得:

torch.isnan(your_tensor).any()

答案 3 :(得分:3)

如果任何值为 nan 则为真:

torch.any(tensor.isnan())

如果都是 nan 则为真:

torch.all(tensor.isnan())

答案 4 :(得分:1)

如果你想直接在张量上调用它:

import torch

x = torch.randn(5, 4)
print(x.isnan().any())

出:

import torch
x = torch.randn(5, 4)
print(x.isnan().any())
tensor(False)