这是慢速代码:
Error in `[.data.table`(X, Y, .N, on = .(x >= y & x <= y + 3), by = .EACHI) :
Column(s) [y & x] not found in i
有没有办法一次性完成/使其更快?
答案 0（得分：1）
向量化可能很困难，甚至不可能。问题的关键在于第二维上的高级索引：例如 `maskB & maskA1` 在每一行中可以有任意数量的 `True` 值，因此无法从中取出一个规则的 `m x n` 数组来进行索引。
使用 Numba（`@njit`）编写的简单 `for` 循环似乎能将性能提升约 7.5 倍（见代码末尾的计时：5.36 ms 对比 40.2 ms）：
# Python 3.6.5, NumPy 1.14.3, Numba 0.38.0
import timeit

import numpy as np
from numba import njit
@njit
def doCounts(maskA1, maskA2, maskA3, counts, maskB):
    """Numba-compiled equivalent of ``doCounts_original``.

    For every column ``j`` with ``maskB[j]`` true, increment ``counts[0, j]``
    if ``maskA1[j]``, ``counts[1, j]`` if ``maskA2[j]`` and ``counts[2, j]``
    if ``maskA3[j]``.

    Parameters
    ----------
    maskA1, maskA2, maskA3, maskB : 1-D boolean arrays
        Column masks; each length must equal ``counts.shape[1]``.
    counts : 2-D integer array with at least 3 rows
        Updated in place.

    Returns
    -------
    The same ``counts`` array object that was passed in, after the update.

    Raises
    ------
    ValueError
        If ``counts`` has fewer than 3 rows or a mask length differs from
        ``counts.shape[1]``.
    """
    n_cols = counts.shape[1]
    # Numba does not bounds-check array access by default, so mismatched
    # shapes would silently read/write out of bounds instead of raising the
    # way NumPy boolean indexing does.
    if counts.shape[0] < 3:
        raise ValueError("counts must have at least 3 rows")
    if (maskB.shape[0] != n_cols or maskA1.shape[0] != n_cols
            or maskA2.shape[0] != n_cols or maskA3.shape[0] != n_cols):
        raise ValueError("mask length must equal counts.shape[1]")
    # Iterate over the columns (not the 3 rows): the masks select columns,
    # exactly as ``counts[row, maskB & maskA] += 1`` does.  Testing maskB
    # first also avoids allocating the three combined masks.
    for j in range(n_cols):
        if not maskB[j]:
            continue
        if maskA1[j]:
            counts[0, j] += 1
        if maskA2[j]:
            counts[1, j] += 1
        if maskA3[j]:
            counts[2, j] += 1
    return counts
def doCounts_original(maskA1, maskA2, maskA3, counts, maskB):
    """Reference NumPy implementation of the masked per-column counts.

    Row ``k`` of ``counts`` (k = 0, 1, 2) is incremented by one in every
    column selected by ``maskB & maskA{k+1}``.  ``counts`` is modified in
    place and the same object is returned.
    """
    for row, mask_a in enumerate((maskA1, maskA2, maskA3)):
        # A boolean mask selects each column at most once, so the in-place
        # ``+= 1`` has no repeated-index pitfalls.
        counts[row, maskB & mask_a] += 1
    return counts
n = 100
np.random.seed(0)
m1, m2, m3, mB = (np.random.randint(0, 2, n**3).astype(bool) for _ in range(4))
counts = np.random.randint(0, 100, (3, n**3))

# Both functions update ``counts`` in place and return that same object, so
# each needs its own copy: passing ``counts`` to both would compare the array
# with itself and the check could never fail.  This first doCounts call also
# triggers Numba's JIT compilation, keeping compile time out of the timings.
assert np.array_equal(doCounts(m1, m2, m3, counts.copy(), mB),
                      doCounts_original(m1, m2, m3, counts.copy(), mB))

# Plain-Python replacement for the IPython ``%timeit`` magic, which is a
# syntax error outside IPython/Jupyter.  Timings originally reported with
# this answer: doCounts 5.36 ms, doCounts_original 40.2 ms per call.
_LOOPS = 10
for label, func in (("doCounts", doCounts),
                    ("doCounts_original", doCounts_original)):
    best = min(timeit.repeat(lambda: func(m1, m2, m3, counts, mB),
                             repeat=5, number=_LOOPS))
    print(f"{label}: {best / _LOOPS * 1e3:.2f} ms per loop (best of 5)")