我有一个形状为a
的numpy数组(n, 3)
,其中包含从0
到m
的整数。 m
和n
都可能相当大。众所周知,从0
到m
的每个整数有时只出现一次,但大多数在a
的某个地方出现两次。连续没有加倍的索引。
我现在想构建“反向”索引,即形状b_row
的两个数组b_col
和(m, 2)
,每行包含(一个或两个)行/ a
中的列索引row_idx
中显示a
。
这有效:
import numpy
a = numpy.array([
[0, 1, 2],
[0, 1, 3],
[2, 3, 4],
[4, 5, 6],
# ...
])
print(a)
b_row = -numpy.ones((7, 2), dtype=int)
b_col = -numpy.ones((7, 2), dtype=int)
count = numpy.zeros(7, dtype=int)
for k, row in enumerate(a):
i = count[row]
b_row[row, i] = k
b_col[row, i] = [0, 1, 2]
count[row] += 1
print(b_row)
print(b_col)
[[0 1 2]
[0 1 3]
[2 3 4]
[4 5 6]]
[[ 0 1]
[ 0 1]
[ 0 2]
[ 1 2]
[ 2 3]
[ 3 -1]
[ 3 -1]]
[[ 0 0]
[ 1 1]
[ 2 0]
[ 2 1]
[ 2 0]
[ 1 -1]
[ 2 -1]]
但由于a
上的显式循环而缓慢。
有关如何提高速度的任何提示?
答案 0 :(得分:2)
这是一个解决方案:
import numpy as np
m = 7
a = np.array([
[0, 1, 2],
[0, 1, 3],
[2, 3, 4],
[4, 5, 6],
# ...
])
print('a:')
print(a)
a_flat = a.flatten() # Or a.ravel() if can modify original array
v1, idx1 = np.unique(a_flat, return_index=True)
a_flat[idx1] = -1
v2, idx2 = np.unique(a_flat, return_index=True)
v2, idx2 = v2[1:], idx2[1:]
rows1, cols1 = np.unravel_index(idx1, a.shape)
rows2, cols2 = np.unravel_index(idx2, a.shape)
b_row = -np.ones((m, 2), dtype=int)
b_col = -np.ones((m, 2), dtype=int)
b_row[v1, 0] = rows1
b_col[v1, 0] = cols1
b_row[v2, 1] = rows2
b_col[v2, 1] = cols2
print('b_row:')
print(b_row)
print('b_col:')
print(b_col)
输出:
a:
[[0 1 2]
[0 1 3]
[2 3 4]
[4 5 6]]
b_row:
[[ 0 1]
[ 0 1]
[ 0 2]
[ 1 2]
[ 2 3]
[ 3 -1]
[ 3 -1]]
b_col:
[[ 0 0]
[ 1 1]
[ 2 0]
[ 2 1]
[ 2 0]
[ 1 -1]
[ 2 -1]]
编辑:
IPython中的一个小基准用于比较。正如@eozd所示,由于{(1)}在O(n)中运行,算法复杂度原则上更高,但对于实际大小,矢量化解决方案似乎仍然快得多:
np.unique
答案 1 :(得分:1)
这是一个仅使用一个argsort
和一系列轻量级索引操作的解决方案:
def grp_start_len(a):
# https://stackoverflow.com/a/50394587/353337
m = numpy.concatenate([[True], a[:-1] != a[1:], [True]])
idx = numpy.flatnonzero(m)
return idx[:-1], numpy.diff(idx)
a_flat = a.flatten()
idx_sort = numpy.argsort(a_flat)
idx_start, count = grp_start_len(a_flat[idx_sort])
res1 = idx_sort[idx_start[count==1]][:, numpy.newaxis]
res1 // 3
res1 % 3
idx = idx_start[count==2]
res2 = numpy.column_stack([idx_sort[idx], idx_sort[idx + 1]])
res2 // 3
res2 % 3
基本思想是,在a
被展平和排序后,所有信息都可以从a_flat_sorted
中的起始索引和整数块的长度中提取。