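A cell list is a data structure that maintains lists of data points in an N-D grid. For example, the following list of 2D indices:
ind = [(0, 1), (1, 0), (0, 1), (0, 0), (0, 0), (0, 0), (1, 1)]
is converted into the following 2x2 cell list, where cell[r][c] holds the positions in ind of all points that fall into grid cell (r, c):
cell = [[[3, 4, 5], [0, 2]],
        [[1], [6]]]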
This is done with the following O(n) algorithm:
# create an empty 2x2 cell list
cell = [[[] for _ in range(2)] for _ in range(2)]
for i in range(len(ind)):
    # nested lists need two subscripts, not a tuple index
    cell[ind[i][0]][ind[i][1]].append(i)
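The loop above is written for two dimensions only. A minimal sketch of the same O(n) loop for an arbitrary N-D grid, assuming the grid shape is known in advance. `build_cell_list` is a hypothetical helper name, and a NumPy object array stands in for the nested lists:

```python
import numpy as np

def build_cell_list(ind, shape):
    """Loop-based cell list for an N-D grid (one O(1) append per point).

    ind   : sequence of N-tuples, the grid cell of each data point
    shape : grid shape, e.g. (2, 2)
    Returns an object array whose elements are Python lists of point indices.
    """
    cell = np.empty(shape, dtype=object)
    # give every cell its own empty list (no shared references)
    for pos in np.ndindex(*shape):
        cell[pos] = []
    # append each point's index to the list of the cell it falls into
    for i, pos in enumerate(ind):
        cell[tuple(pos)].append(i)
    return cell

ind = [(0, 1), (1, 0), (0, 1), (0, 0), (0, 0), (0, 0), (1, 1)]
cell = build_cell_list(ind, (2, 2))
print(cell[0, 0], cell[0, 1], cell[1, 0], cell[1, 1])  # [3, 4, 5] [0, 2] [1] [6]
```

This is still a Python-level loop, so it does not answer the vectorization question; it only shows that the 2D loop generalizes directly.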
Is there a vectorized way in numpy to convert the list of indices (ind) into the cell structure above?
Answer (score: 0)
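I don't think there is a good pure numpy solution, but you can use pythran or, if you don't want to touch a compiler, scipy.sparse.
Cf. this Q&A, which is essentially a 1D version of your problem.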
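Why the 1D version is enough: every N-D cell index can be flattened into a single bin number, so sorting points into N-D cells reduces to sorting them into 1D bins. A minimal illustration on the toy example with `np.ravel_multi_index`, the same NumPy call used by `to_cell` in the main script:

```python
import numpy as np

ind = np.array([(0, 1), (1, 0), (0, 1), (0, 0), (0, 0), (0, 0), (1, 1)])
shape = (2, 2)

# row-major (C-order) flat bin number of each point: row * 2 + col
flat = np.ravel_multi_index(tuple(ind.T), shape)
print(flat)  # [1 2 1 0 0 0 3]

# the mapping is invertible, so bin k corresponds to cell np.unravel_index(k, shape)
rows, cols = np.unravel_index(flat, shape)
assert (rows == ind[:, 0]).all() and (cols == ind[:, 1]).all()
```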
[stb_pthr.py] (simplified from "Most efficient way to sort an array into bins specified by an index array?")
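import numpy as np

#pythran export sort_to_bins(int[:], int)
def sort_to_bins(idx, mx=-1):
    # counting sort: group the positions of idx by bin value in O(n)
    if mx == -1:
        mx = idx.max() + 1
    # cnts[k + 1] = number of entries in bin k
    cnts = np.zeros(mx + 1, int)
    for i in range(idx.size):
        cnts[idx[i] + 1] += 1
    # prefix sum: cnts[k] = start offset of bin k
    for i in range(1, cnts.size):
        cnts[i] += cnts[i-1]
    # scatter each position into its bin; afterwards cnts[k] = end of bin k
    res = np.empty_like(idx)
    for i in range(idx.size):
        res[cnts[idx[i]]] = i
        cnts[idx[i]] += 1
    return res, cnts[:-1]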
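For comparison, the same (reordering, bin ends) pair can be produced in plain NumPy by swapping the counting sort for a stable `argsort`. This is a sketch that is not part of the original answer, and `sort_to_bins_numpy` is a hypothetical name; it gives up the O(n) bound (the sort is O(n log n)), which is consistent with the point that there is no good pure-NumPy O(n) solution:

```python
import numpy as np

def sort_to_bins_numpy(idx, mx=-1):
    """Hypothetical pure-NumPy stand-in for sort_to_bins.

    Returns (order, ends): order lists point positions grouped by bin,
    ends[k] is the end offset of bin k in order.
    O(n log n) because of the sort, unlike the O(n) counting sort.
    """
    idx = np.asarray(idx)
    if mx == -1:
        mx = idx.max() + 1
    # stable sort keeps points in their original order within each bin
    order = np.argsort(idx, kind="stable")
    # cumulative bin sizes give the end offset of each bin
    ends = np.cumsum(np.bincount(idx, minlength=mx))
    return order, ends

order, ends = sort_to_bins_numpy([1, 2, 1, 0, 0, 0, 3], 4)
print(order, ends)  # [3 4 5 0 2 1 6] [3 5 6 7]
# splitting at the bin ends yields one piece per bin, plus a trailing empty piece
pieces = np.split(order, ends)
```

Because it returns the same pair as `sort_to_bins`, it could in principle be dropped into `to_cell` in its place.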
Compile:
pythran stb_pthr.py
Main script:
import numpy as np

try:
    from stb_pthr import sort_to_bins
    HAVE_PYTHRAN = True
except ImportError:
    HAVE_PYTHRAN = False

from scipy import sparse

def fallback(flat, maxind):
    # one entry per row i, at column flat[i]; after conversion to CSC the
    # row indices come out grouped by column, i.e. point indices grouped by bin
    res = sparse.csr_matrix((np.zeros_like(flat), flat, np.arange(len(flat)+1)),
                            (len(flat), maxind)).tocsc()
    return res.indices, res.indptr[1:-1]

if not HAVE_PYTHRAN:
    sort_to_bins = fallback

def to_cell(data, shape=None):
    data = np.asanyarray(data)
    if shape is None:
        # grid extent along each axis
        *shape, = (1 + c.max() for c in data.T)
    # flatten N-D cell indices to 1D bin numbers
    flat = np.ravel_multi_index((*data.T,), shape)
    reord, bnds = sort_to_bins(flat, np.prod(shape))
    # object array whose element at each cell is the array of point indices in it
    return np.frompyfunc(np.split(reord, bnds).__getitem__, 1, 1)(
        np.arange(np.prod(shape)).reshape(shape))

ind = [(0, 1), (1, 0), (0, 1), (0, 0), (0, 0), (0, 0), (1, 1)]
print(to_cell(ind))

from timeit import timeit
ind = np.transpose(np.unravel_index(np.random.randint(0, 100, (1_000_000)), (10, 10)))
# 10 runs each; total seconds * 100 = milliseconds per run
if HAVE_PYTHRAN:
    print(timeit(lambda: to_cell(ind), number=10)*100)
    sort_to_bins = fallback # !!! MUST REMOVE THIS LINE AFTER TESTING
print(timeit(lambda: to_cell(ind), number=10)*100)
Sample run: the script prints the answer to the OP's toy example, followed by the timings (in ms) of the pythran and scipy solutions on the 1,000,000-point example.
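The scipy fallback relies on a layout trick: a CSR matrix with exactly one stored entry per row i, in column flat[i], becomes after conversion to CSC a list of row numbers grouped by column, i.e. point indices grouped by bin. A small illustration on the flattened toy indices; the variable names are illustrative, and ones are used as data because only the sparsity pattern matters:

```python
import numpy as np
from scipy import sparse

flat = np.array([1, 2, 1, 0, 0, 0, 3])  # flattened cell numbers of the toy example
n, nbins = len(flat), 4

# CSR with one stored entry per row: row i has its single entry in column flat[i]
m = sparse.csr_matrix((np.ones(n), flat, np.arange(n + 1)), shape=(n, nbins))
c = m.tocsc()

# in CSC form, c.indices holds row numbers (= point indices) grouped by column (= bin),
# and c.indices[c.indptr[k]:c.indptr[k + 1]] are the points in bin k
print(c.indices)  # [3 4 5 0 2 1 6]
print(c.indptr)   # [0 3 5 6 7]
```

`fallback` returns `indptr[1:-1]` as the split points, which is why `np.split` then yields exactly one piece per bin.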