如何找到两个numpy浮点数组的交集?:
a = np.arange(2, 3, 0.1)
b = np.array([2.3, 2.4, 2.5])
out_data = np.intersect1d(a, b)
结果是
out_data -> ndarray: []
答案 0 :(得分:4)
由于浮点数的工作方式,在您的示例中,[3]不是2.3,而是2.3000000000000003。这是因为0.1在IEEE双精度浮点数中没有精确的表示。 numpy中的intersect1d
方法实际上只适用于整数。要解决这个问题,您应该实现自己的方法,该方法需要一个容差来确定两个浮点数是否足够接近。
答案 1 :(得分:2)
这是使用NumPy's broadcasting capability
-
tol = 1e-5 # tolerance
out = b[(np.abs(a[:,None] - b) < tol).any(0)]
示例运行 -
In [31]: a
Out[31]: array([ 2. , 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9])
In [32]: b
Out[32]: array([ 2.3 , 2.4 , 2.5 , 2.25, 2.1 ])
In [33]: tol = 1e-5 # tolerance
In [34]: b[(np.abs(a[:,None] - b) < tol).any(0)]
Out[34]: array([ 2.3, 2.4, 2.5, 2.1])
答案 2 :(得分:0)
以下例程将返回相对于列表a。
的指定容差范围内的公共值索引def findOverlap(self, a, b, rtol = 1e-05, atol = 1e-08, equal_nan = False):
overlap_indexes = []
for i, item_a in enumerate(a):
for item_b in b:
if np.isclose(item_a, item_b, rtol = rtol, atol = atol, equal_nan = equal_nan):
overlap_indexes.append(i)
return overlap_indexes
例如
a = np.arange(2, 3, 0.1).tolist()
b = np.array([2.3, 2.4, 2.5]).tolist()
self.findOverlap(a, b)
-> overlap_indexes:[3, 4, 5]