计算numpy数组中有多少元素在每个其他元素的delta内

时间:2016-12-30 10:24:59

标签: python performance pandas numpy

考虑数组x和delta变量d

np.random.seed([3,1415])
x = np.random.randint(100, size=10)
d = 10

对于x中的每个元素,我想计算每个元素中有多少其他元素在delta d 距离之内。

所以x看起来像

print(x)

[11 98 74 90 15 55 13 11 13 26]

结果应为

[5 2 1 2 5 1 5 5 5 1]

我尝试了什么
策略:

  • 使用广播来消除外部差异
  • 外差的绝对值
  • 总和多少超过阈值
(np.abs(x[:, None] - x) <= d).sum(-1)

[5 2 1 2 5 1 5 5 5 1]

这很有效。但是,它没有扩展。外部差异是O(n ^ 2)时间。如何获得不能用二次时间缩放的相同解决方案?

2 个答案:

答案 0 :(得分:4)

此帖中列出了另外两个变种,基于OP's answer post中的searchsorted strategy

def pir3(a,d):  # Short & less efficient
    sidx = a.argsort()
    p1 = a.searchsorted(a+d,'right',sorter=sidx)
    p2 = a.searchsorted(a-d,sorter=sidx)
    return p1 - p2

def pir4(a, d):   # Long & more efficient
    s = a.argsort()

    y = np.empty(s.size,dtype=np.int64)
    y[s] = np.arange(s.size)

    a_ = a[s]
    return (
        a_.searchsorted(a_ + d, 'right')
        - a_.searchsorted(a_ - d)
    )[y]

更有效的方法可以从this post获得s.argsort()的有效想法。

运行时测试 -

In [155]: # Inputs
     ...: a = np.random.randint(0,1000000,(10000))
     ...: d = 10


In [156]: %timeit pir2(a,d) #@ piRSquared's post solution
     ...: %timeit pir3(a,d)
     ...: %timeit pir4(a,d)
     ...: 
100 loops, best of 3: 2.43 ms per loop
100 loops, best of 3: 4.44 ms per loop
1000 loops, best of 3: 1.66 ms per loop

答案 1 :(得分:1)

策略

  • 由于x未必排序,我们会对其进行排序并通过argsort跟踪排序排列,以便我们可以撤消排列。
  • 我们将np.searchsorted上的xx - d一起使用,以找到x的值开始超过x - d时的起始位置。
  • 在另一方再做一次,除非我们必须使用np.searchsorted参数side='right'并使用x + d
  • 利用右侧和左侧搜索范围之间的差异来计算每个元素+/- d内元素的数量
  • 使用argsort来反转排序排列

将问题定义为pir1

def pir1(a, d):
    return (np.abs(a[:, None] - a) <= d).sum(-1)

我们将定义一个新函数pir2

def pir2(a, d):
    s = x.argsort()
    a_ = a[s]
    return (
        a_.searchsorted(a_ + d, 'right')
        - a_.searchsorted(a_ - d)
    )[s.argsort()]

演示

pir1(x, d)

[5 2 1 2 5 1 5 5 5 1]    
pir1(x, d)

[5 2 1 2 5 1 5 5 5 1]    

<强> 定时
pir2是明显的赢家!

代码

功能

def pir1(a, d):
    return (np.abs(a[:, None] - a) <= d).sum(-1)

def pir2(a, d):
    s = x.argsort()
    a_ = a[s]
    return (
        a_.searchsorted(a_ + d, 'right')
        - a_.searchsorted(a_ - d)
    )[s.argsort()]

#######################
# From Divakar's post #
#######################
def pir3(a,d):  # Short & less efficient
    sidx = a.argsort()
    p1 = a.searchsorted(a+d,'right',sorter=sidx)
    p2 = a.searchsorted(a-d,sorter=sidx)
    return p1 - p2

def pir4(a, d):   # Long & more efficient
    s = a.argsort()

    y = np.empty(s.size,dtype=np.int64)
    y[s] = np.arange(s.size)

    a_ = a[s]
    return (
        a_.searchsorted(a_ + d, 'right')
        - a_.searchsorted(a_ - d)
    )[y]

测试

from timeit import timeit

results = pd.DataFrame(
    index=np.arange(1, 50),
    columns=['pir%s' %i for i in range(1, 5)])

for i in results.index:
    np.random.seed([3,1415])
    x = np.random.randint(1000000, size=i)
    for j in results.columns:
        setup = 'from __main__ import x, {}'.format(j)
        results.loc[i, j] = timeit('{}(x, 10)'.format(j), setup=setup, number=10000)

results.plot()

enter image description here

扩展到更大的数组
摆脱pir1

from timeit import timeit

results = pd.DataFrame(
    index=np.arange(1, 11) * 1000,
    columns=['pir%s' %i for i in range(2, 5)])

for i in results.index:
    np.random.seed([3,1415])
    x = np.random.randint(1000000, size=i)
    for j in results.columns:
        setup = 'from __main__ import x, {}'.format(j)
        results.loc[i, j] = timeit('{}(x, 10)'.format(j), setup=setup, number=100)

results.insert(0, 'pir1', 0)

results.plot()

enter image description here