How do I select the elements of a 2D array that lie between per-row ranges given by index arrays, in numpy?

Time: 2017-11-24 14:30:38

Tags: python arrays numpy

Suppose I have an array

arr = np.random.randint(1000, size=(100,30))

and I have start and end indices like these
first= np.random.randint(5, size=(100,))
second = first + 20

How can I select the data in the 2D array that falls between the ranges given by the 1D arrays?

Currently I use a for loop to do the job

m =[]
for i,j in enumerate(arr):
    m.append(j[first[i]:second[i]])

np.array(m).shape

(100, 20)

How can I do the same thing in numpy, or vectorize this solution?

1 Answer:

Answer 0: (Score: 3)

Use broadcasting and some masking -

import numpy as np

def extract_rows(arr, first, second, fillval=0):
    # Define the extent of the output array. The stop indices are exclusive,
    # so clip them to the number of columns (not to the last valid index).
    L = (second.clip(max=arr.shape[1]) - first).max()
    offsets = np.arange(L)

    # Get the ranged indices to cover the extent of cols for all rows
    idx = first[:,None] + offsets

    # Now, the indices that reach or pass the limit set by the second values
    # are to be set as fillval. Similarly, the indices that go beyond the
    # column length of the input array are invalid as well. So, get the
    # combined mask.
    invalid_mask = (idx >= arr.shape[1]) | ((second - first)[:,None] <= offsets)

    # Set invalid places in idx as zeros or just any value, but make sure those
    # are indexable into input array
    idx[invalid_mask] = 0

    # Finally index into input array with those and set the invalid ones in it
    # with fillval.
    return np.where(invalid_mask, fillval, arr[np.arange(len(idx))[:,None], idx])
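
For the exact case in the question, every range has the same width (second = first + 20) and stays inside the 30 columns, so no masking should be needed and the broadcasting step alone is enough. A minimal sketch under those assumptions; the seeded generator and the names width, rows, cols and expected are only for illustration and are not part of the original post:

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded only so the sketch is reproducible
arr = rng.integers(1000, size=(100, 30))
first = rng.integers(5, size=(100,))
width = 20  # assumption: every row spans the same number of columns

# Broadcast a (100, 1) column of start indices against a (20,) row of
# offsets to build a (100, 20) grid of column indices.
cols = first[:, None] + np.arange(width)
rows = np.arange(arr.shape[0])[:, None]

# Integer-array indexing picks arr[i, cols[i, k]] for every i, k.
out = arr[rows, cols]

# Cross-check against the loop from the question.
expected = np.array([row[s:s + width] for row, s in zip(arr, first)])
print(out.shape, np.array_equal(out, expected))
```

If the ranges can differ in length or run past the last column, this shortcut no longer applies and the masked version is the safer choice.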

Sample run -

1] Input array:

In [639]: arr
Out[639]: 
array([[87, 83, 36, 30, 58, 35, 85, 87],
       [17, 58, 51, 39, 56, 27, 97, 26],
       [33, 45,  1, 90, 87, 49, 30, 37],
       [92, 29, 17,  9, 81, 35, 47, 33],
       [61, 87, 22, 44, 97, 43, 96, 66],
       [47, 67, 28, 74, 50, 93, 22, 19],
       [77, 82, 35, 51, 25, 29, 25, 29],
       [95, 24, 70, 70, 34, 35, 50, 53],
       [53, 64, 84, 46, 21, 89, 44, 52],
       [92, 78, 21, 53, 53, 39,  7, 59]])

2] Input start and stop indices per row:

In [640]: np.c_[first, second]
Out[640]: 
array([[ 0,  1],
       [ 0,  6],
       [ 2,  3],
       [ 3, 12],
       [ 2,  8],
       [ 2,  4],
       [ 4, 12],
       [ 0,  0],
       [ 2,  7],
       [ 4,  6]])

3] Output array:

In [652]: extract_rows(arr, first, second)
Out[652]: 
array([[87,  0,  0,  0,  0,  0],
       [17, 58, 51, 39, 56, 27],
       [ 1,  0,  0,  0,  0,  0],
       [ 9, 81, 35, 47, 33,  0],
       [22, 44, 97, 43, 96, 66],
       [28, 74,  0,  0,  0,  0],
       [25, 29, 25, 29,  0,  0],
       [ 0,  0,  0,  0,  0,  0],
       [84, 46, 21, 89, 44,  0],
       [53, 39,  0,  0,  0,  0]])
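
One way to sanity-check the sample run is to compare it with a plain slicing loop that pads each row with zeros. The sketch below restates the function compactly (clipping the stop indices to the column count, since they are exclusive) and reuses the sample data; the loop and the name ref are an illustrative reference, not part of the original answer:

```python
import numpy as np

def extract_rows(arr, first, second, fillval=0):
    ncols = arr.shape[1]
    # Widest requested range after clipping the exclusive stops to ncols
    offsets = np.arange((second.clip(max=ncols) - first).max())
    idx = first[:, None] + offsets
    # Positions past a row's stop index or past the last column
    invalid = (idx >= ncols) | ((second - first)[:, None] <= offsets)
    idx[invalid] = 0  # any in-bounds value works; it is masked out below
    return np.where(invalid, fillval, arr[np.arange(len(idx))[:, None], idx])

arr = np.array([[87, 83, 36, 30, 58, 35, 85, 87],
                [17, 58, 51, 39, 56, 27, 97, 26],
                [33, 45,  1, 90, 87, 49, 30, 37],
                [92, 29, 17,  9, 81, 35, 47, 33],
                [61, 87, 22, 44, 97, 43, 96, 66],
                [47, 67, 28, 74, 50, 93, 22, 19],
                [77, 82, 35, 51, 25, 29, 25, 29],
                [95, 24, 70, 70, 34, 35, 50, 53],
                [53, 64, 84, 46, 21, 89, 44, 52],
                [92, 78, 21, 53, 53, 39,  7, 59]])
first = np.array([0, 0, 2, 3, 2, 2, 4, 0, 2, 4])
second = np.array([1, 6, 3, 12, 8, 4, 12, 0, 7, 6])

out = extract_rows(arr, first, second)

# Reference: slice each row the slow way and left-align it in a zero array.
ref = np.zeros_like(out)
for i, (s, e) in enumerate(zip(first, second)):
    piece = arr[i, s:e]  # plain slicing already truncates at the row end
    ref[i, :len(piece)] = piece

print(np.array_equal(out, ref))
```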