鉴于任何一般float torch.Tensor
,可能包含一些 NaN 值,我正在寻找一种有效的方法来将其中的所有NaN值替换为零,或者将它们全部删除并过滤在另一个新的Tensor中输出“有用”的值。
我知道这样做的一个简单方法是手动迭代给定张量中的所有值(并相应地用零替换它们或者为新张量拒绝它们)。
是否有一些预定义的Torch功能或功能组合可以在性能方面更有效地实现这一点,这依赖于Torch固有的CPU-GPU优化?
答案 0 :(得分:3)
好吧,看起来torch
中没有函数检查NaN的张量。但是因为NaN!= NaN,所以有一个解决方法:
a = torch.rand(4, 5)
a[2][3] = tonumber('nan')
nan_mask = a:ne(a)
notnan_mask = a:eq(a)
print(a)
0.2434 0.1731 0.3440 0.3340 0.0519
0.0932 0.4067 nan 0.1827 0.5945
0.3020 0.1035 0.5415 0.3329 0.7881
0.6108 0.9498 0.0406 0.9335 0.3582
[torch.DoubleTensor of size 4x5]
print(nan_mask)
0 0 0 0 0
0 0 1 0 0
0 0 0 0 0
0 0 0 0 0
[torch.ByteTensor of size 4x5]
拥有这些面具,您可以有效地提取NaN /而不是NaN值并将其替换为您想要的任何值:
print(a[notnan_mask])
...
[torch.DoubleTensor of size 19]
a[nan_mask] = 42
print(a)
0.2434 0.1731 0.3440 0.3340 0.0519
0.0932 0.4067 42.0000 0.1827 0.5945
0.3020 0.1035 0.5415 0.3329 0.7881
0.6108 0.9498 0.0406 0.9335 0.3582
[torch.DoubleTensor of size 4x5]