找到numpy浮点数组的交集

时间:2015-09-10 23:59:39

标签: python numpy

如何找到两个numpy浮点数组的交集?:

a = np.arange(2, 3, 0.1)
b = np.array([2.3, 2.4, 2.5])
out_data = np.intersect1d(a, b)

结果是

out_data -> ndarray: []

3 个答案:

答案 0 :(得分:4)

由于浮点数的工作方式,在您的示例中,[3]不是2.3,而是2.3000000000000003。这是因为0.1在IEEE双精度浮点数中没有精确的表示。 numpy中的intersect1d方法实际上只适用于整数。要解决这个问题,您应该实现自己的方法,该方法需要一个容差来确定两个浮点数是否足够接近。

答案 1 :(得分:2)

这是使用NumPy's broadcasting capability -

的矢量化方法
tol = 1e-5 # tolerance
out = b[(np.abs(a[:,None] - b) < tol).any(0)]

示例运行 -

In [31]: a
Out[31]: array([ 2. ,  2.1,  2.2,  2.3,  2.4,  2.5,  2.6,  2.7,  2.8,  2.9])

In [32]: b
Out[32]: array([ 2.3 ,  2.4 ,  2.5 ,  2.25,  2.1 ])

In [33]: tol = 1e-5 # tolerance

In [34]: b[(np.abs(a[:,None] - b) < tol).any(0)]
Out[34]: array([ 2.3,  2.4,  2.5,  2.1])

答案 2 :(得分:0)

以下例程将返回相对于列表a。

的指定容差范围内的公共值索引
def findOverlap(self, a, b, rtol = 1e-05, atol = 1e-08, equal_nan = False):
    overlap_indexes = []
    for i, item_a in enumerate(a):
        for item_b in b:
            if np.isclose(item_a, item_b, rtol = rtol, atol = atol, equal_nan = equal_nan):
                overlap_indexes.append(i)
    return overlap_indexes

例如

a = np.arange(2, 3, 0.1).tolist()
b = np.array([2.3, 2.4, 2.5]).tolist()
self.findOverlap(a, b)

-> overlap_indexes:[3, 4, 5]