此问题与许多与警告RuntimeWarning: invalid value encountered in greater/less/etc
然而,我无法找到解决我的特定问题的方法,我认为应该有一个。
所以,我有numpy.ndarray
类似于这个:
array([[ nan, 1., nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
...,
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan]])
我想计算array > 0.5
,它会给出我想要的结果,但是会有与nan
进行比较的警告:
__main__:1: RuntimeWarning: invalid value encountered in greater
Out[68]:
array([[False, True, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]], dtype=bool)
我基本上想要计算array > 0.5
,但没有出现警告。
我的限制:
with np.errstate(invalid='ignore'):
我想出了一个简单的解决方案:
nan
(array[np.isnan(array)] = -np.inf
),在进行比较后将其恢复(array[array == -np.inf] = np.nan
)但是我觉得这些计算只是浪费时间(我认为)它应该存在直接的方法来实现这一点。我一直在探索numpy.ma
模块和numpy.where
函数,但我无法找到我想要的“直接”解决方案。
对此有何想法?
答案 0 :(得分:5)
只要比较包含至少一个NaN的数组,就会发出警告。解决方案是使用masking
仅比较非NaN元素,我们将尝试使用通用解决方案在comparison based NumPy ufuncs
的帮助下涵盖所有类型的比较,如下所示 -
def compare_nan_array(func, a, thresh):
out = ~np.isnan(a)
out[out] = func(a[out] , thresh)
return out
这个想法是:
获取非NaN的掩码。
使用它从输入数组中获取非NaN值。然后执行所需的比较(大于,大于等等)以获得另一个掩码,该掩码表示掩蔽位置的比较掩码输出。
使用它来优化非NaN的掩码,这是最终输出。
示例运行 -
In [41]: np.random.seed(0)
In [42]: a = np.random.randint(0,9,(4,5)).astype(float)
In [43]: a.ravel()[np.random.choice(a.size, 16, replace=0)] = np.nan
In [44]: a
Out[44]:
array([[ nan, nan, nan, nan, nan],
[ nan, nan, nan, 4., 7.],
[ nan, nan, nan, 1., nan],
[ nan, 7., nan, nan, nan]])
In [45]: a > 5 # Shows warning with the usual comparison
__main__:1: RuntimeWarning: invalid value encountered in greater
Out[45]:
array([[False, False, False, False, False],
[False, False, False, False, True],
[False, False, False, False, False],
[False, True, False, False, False]], dtype=bool)
# With suggested masking based method
In [46]: compare_nan_array(np.greater, a, 5)
Out[46]:
array([[False, False, False, False, False],
[False, False, False, False, True],
[False, False, False, False, False],
[False, True, False, False, False]], dtype=bool)
让我们通过测试lesser than 5
-
In [47]: a < 5
__main__:1: RuntimeWarning: invalid value encountered in less
Out[47]:
array([[False, False, False, False, False],
[False, False, False, True, False],
[False, False, False, True, False],
[False, False, False, False, False]], dtype=bool)
In [48]: compare_nan_array(np.less, a, 5)
Out[48]:
array([[False, False, False, False, False],
[False, False, False, True, False],
[False, False, False, True, False],
[False, False, False, False, False]], dtype=bool)
答案 1 :(得分:0)
有一种更好的方法-您不想永远抑制警告,因为它可以帮助您以后找到其他错误。
遵循此问题中的建议:RuntimeWarning: invalid value encountered in divide
如果结果是您想要的结果,则只需编写:
with np.errstate(invalid='ignore'):
result = (array > 0.5)
# ... use result, and your warnings are not suppressed.
否则,您可以通过复制数组来满足限制:
to_compare = array.copy()
to_compare[np.isnan(to_compare)] = 0.5 # you don't need -np.inf, anything <= 0.5 is OK
result = (to_compare > 0.5)
您无需“恢复”阵列中的NaN。