我有一个2D python数组,我想以奇怪的方式切片 - 我希望从每行的不同位置开始一个恒定的宽度切片。如果可能的话,我想以矢量化的方式做到这一点。
e.g。我的数组A=np.array([range(5), range(5)])
看起来像
array([[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4]])
我想按如下方式对其进行切片:每行2个元素,从0和3位开始。起始位置存储在b=np.array([0,3])
中。因此,期望的输出是:np.array([[0,1],[3,4]])
,即
array([[0, 1],
[3, 4]])
我试图得到这个结果的显而易见的事情是A[:,b:b+2]
,但这不起作用,我找不到任何可能的结果。
速度非常重要,因为它可以在一个循环中以较大的数组运行,而且我不想将代码的其他部分瓶颈。
答案 0 :(得分:2)
您可以使用np.take()
:
In [21]: slices = np.dstack([b, b+1])
In [22]: np.take(arr, slices)
Out[22]:
array([[[0, 1],
[3, 4]]])
答案 1 :(得分:2)
方法#1:以下是一种方法broadcasting
获取所有索引,然后使用advanced-indexing
提取这些索引 -
def take_per_row(A, indx, num_elem=2):
all_indx = indx[:,None] + np.arange(num_elem)
return A[np.arange(all_indx.shape[0])[:,None], all_indx]
示例运行 -
In [340]: A
Out[340]:
array([[0, 5, 2, 6, 3, 7, 0, 0],
[3, 2, 3, 1, 3, 1, 3, 7],
[1, 7, 4, 0, 5, 1, 5, 4],
[0, 8, 8, 6, 8, 6, 3, 1],
[2, 5, 2, 5, 6, 7, 4, 3]])
In [341]: indx = np.array([0,3,1,5,2])
In [342]: take_per_row(A, indx)
Out[342]:
array([[0, 5],
[1, 3],
[7, 4],
[6, 3],
[2, 5]])
方法#2:使用np.lib.stride_tricks.as_strided
-
from numpy.lib.stride_tricks import as_strided
def take_per_row_strided(A, indx, num_elem=2):
m,n = A.shape
A.shape = (-1)
s0 = A.strides[0]
l_indx = indx + n*np.arange(len(indx))
out = as_strided(A, (len(A)-num_elem+1, num_elem), (s0,s0))[l_indx]
A.shape = m,n
return out
从200
矩阵每行获取2000x4000
的运行时测试
In [447]: A = np.random.randint(0,9,(2000,4000))
In [448]: indx = np.random.randint(0,4000-200,(2000))
In [449]: out1 = take_per_row(A, indx, 200)
In [450]: out2 = take_per_row_strided(A, indx, 200)
In [451]: np.allclose(out1, out2)
Out[451]: True
In [452]: %timeit take_per_row(A, indx, 200)
100 loops, best of 3: 2.14 ms per loop
In [453]: %timeit take_per_row_strided(A, indx, 200)
1000 loops, best of 3: 435 µs per loop
答案 2 :(得分:1)
您可以设置一个花哨的索引方法来查找正确的元素:
A = np.arange(10).reshape(2,-1)
x = np.stack([np.arange(A.shape[0])]* 2).T
y = np.stack([b, b+1]).T
A[x, y]
array([[0, 1],
[8, 9]])
与@ Kasramvd的np.take
回答比较:
slices = np.dstack([b, b+1])
np.take(A, slices)
array([[[0, 1],
[3, 4]]])
默认情况下, np.slice
取自展平的数组,而不是行。使用axis = 1
参数可以获得所有行的所有切片:
np.take(A, slices, axis = 1)
array([[[[0, 1],
[3, 4]]],
[[[5, 6],
[8, 9]]]])
哪个需要更多处理。