加快简单的多维计数器代码

时间:2018-10-19 23:35:08

标签: python performance numpy optimization indexing

这是慢速代码:

Error in `[.data.table`(X, Y, .N, on = .(x >= y & x <= y + 3), by = .EACHI) : 
  Column(s) [y & x] not found in i

有没有办法一次性完成/使其更快?

1 个答案:

答案 0 :(得分:1)

向量化可能很困难或不可能。这里的提示是第二维中的高级索引,例如maskB & maskA1,每行可以有任意的True值。因此,您无法隔离m x n数组进行索引。

使用Custom Resource Handler的简单for循环似乎将性能提高了一个因素:

# Python 3.6.5, NumPy 1.14.3, Numba 0.38.0

import numpy as np
from numba import njit

@njit
def doCounts(maskA1, maskA2, maskA3, counts, maskB):
    mask1, mask2, mask3 = maskB & maskA1, maskB & maskA2, maskB & maskA3
    for i in range(counts.shape[0]):
        m1, m2, m3 = mask1[i], mask2[i], mask3[i]
        for j in range(counts.shape[1]):
            if m1:
                counts[0, j] += 1
            if m2:
                counts[1, j] += 1
            if m3:
                counts[2, j] += 1
    return counts

def doCounts_original(maskA1, maskA2, maskA3, counts, maskB):
    counts[0, maskB & maskA1] += 1
    counts[1, maskB & maskA2] += 1
    counts[2, maskB & maskA3] += 1
    return counts

n = 100
np.random.seed(0)
m1, m2, m3, mB = (np.random.randint(0, 2, n**3).astype(bool) for _ in range(4))
counts = np.random.randint(0, 100, (3, n**3))

assert np.array_equal(doCounts(m1, m2, m3, counts, mB),
                      doCounts_original(m1, m2, m3, counts, mB))

%timeit doCounts(m1, m2, m3, counts, mB)           # 5.36 ms
%timeit doCounts_original(m1, m2, m3, counts, mB)  # 40.2 ms