我想通过给定指标(x和y轴)修改空位图。 对于指标给出的每个坐标,该值应该增加一个。
到目前为止,一切似乎都很好。但如果我的指标数组中有一些类似的指标,它只会提高一次值。
>>> img
array([[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]])
>>> inds
array([[0, 0],
[3, 4],
[3, 4]])
操作:
>>> img[inds[:,1], inds[:,0]] += 1
结果:
>>> img
array([[1, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 1, 0]])
预期结果:
>>> img
array([[1, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 2, 0]])
有人知道如何解决这个问题吗?优选地,不使用循环的快速方法。
答案 0 :(得分:6)
这是一种方式。计算算法courtesy of @AlexRiley。
有关img
和inds
相对大小的效果影响,请参阅@PaulPanzer's answer。
# count occurrences of each row and return array
counts = (inds[:, None] == inds).all(axis=2).sum(axis=1)
# apply indices and counts
img[inds[:,1], inds[:,0]] += counts
print(img)
array([[1, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 2, 0]])
答案 1 :(得分:5)
你可以使用numpy.add.at
进行一些操作来准备指数。
np.add.at(img, tuple(inds[:, [1, 0]].T), 1)
如果你有更大的inds
数组,这种方法应该保持快速...(尽管Paul Panzer's solution更快)
答案 2 :(得分:4)
关于另外两个答案的两个评论:
1)使用np.unique
和axis
关键字return_counts
可以改善@ jpp。
2)如果我们转换为平面索引,我们可以使用np.bincount
,这通常(但并非总是如此,请参见基准测试中的最后一个测试用例)比np.add.at
更快。
感谢@miradulo获得基准测试的初始版本。
import numpy as np
def jpp(img, inds):
counts = (inds[:, None] == inds).all(axis=2).sum(axis=1)
img[inds[:,1], inds[:,0]] += counts
def jpp_pp(img, inds):
unq, cnts = np.unique(inds, axis=0, return_counts=True)
img[unq[:,1], unq[:,0]] += cnts
def miradulo(img, inds):
np.add.at(img, tuple(inds[:, [1, 0]].T), 1)
def pp(img, inds):
imgf = img.ravel()
indsf = np.ravel_multi_index(inds.T[::-1], img.shape[::-1])
imgf += np.bincount(indsf, None, img.size)
inds = np.random.randint(0, 5, (3, 2))
big_inds = np.random.randint(0, 5, (10000, 2))
sml_inds = np.random.randint(0, 1000, (5, 2))
from timeit import timeit
for f in jpp, jpp_pp, miradulo, pp:
print(f.__name__)
for i, n, a in [(inds, 1000, 5), (big_inds, 10, 5), (sml_inds, 10, 1000)]:
img = np.zeros((a, a), int)
print(timeit("f(img, i)", globals=dict(img=img, i=i, f=f), number=n) * 1000 / n, 'ms')
输出:
jpp
0.011815106990979984 ms
2623.5026352020213 ms
0.04642329877242446 ms
jpp_pp
0.041291153989732265 ms
5.418520100647584 ms
0.05826510023325682 ms
miradulo
0.007099648006260395 ms
0.7788308983435854 ms
0.009103797492571175 ms
pp
0.0035401539935264736 ms
0.06540440081153065 ms
3.486583800986409 ms