How to mix numpy slices to list of indices?

时间:2016-04-12 00:53:27

标签: python arrays numpy indexing slice

I have a numpy.array, called grid, with shape:

grid.shape = [N, M_1, M_2, ..., M_N]

The values of N, M_1, M_2, ..., M_N are known only after initialization.

For this example, let's say N=3 and M_1 = 20, M_2 = 17, M_3 = 9:

grid = np.arange(3*20*17*9).reshape(3, 20, 17, 9)

I am trying to loop over this array, as follows:

for indices, val in np.ndenumerate(grid[0]):
    print indices
    _some_func_with_N_arguments(*grid[:, indices])

At the first iteration, indices = (0, 0, 0) and:

grid[:, indices] # array with shape 3,3,17,9

whereas I want it to be an array of three elements only, much like:

grid[:, indices[0], indices[1], indices[2]] # array([   0, 3060, 6120])

which however I cannot implement like in the above line, because I don't know a-priori what is the length of indices.

I am using python 2.7, but a version-agnostic implementation is welcome :-)

3 个答案:

答案 0 :(得分:3)

我想你想要这样的东西:

In [134]: x=np.arange(24).reshape(4,3,2)

In [135]: x
Out[135]: 
array([[[ 0,  1],
        [ 2,  3],
        [ 4,  5]],

       [[ 6,  7],
        [ 8,  9],
        [10, 11]],

       [[12, 13],
        [14, 15],
        [16, 17]],

       [[18, 19],
        [20, 21],
        [22, 23]]])

In [136]: for i,j in np.ndindex(x[0].shape):
     ...:     print(i,j,x[:,i,j])
     ...:     
(0, 0, array([ 0,  6, 12, 18]))
(0, 1, array([ 1,  7, 13, 19]))
(1, 0, array([ 2,  8, 14, 20]))
(1, 1, array([ 3,  9, 15, 21]))
(2, 0, array([ 4, 10, 16, 22]))
(2, 1, array([ 5, 11, 17, 23]))

第一行是:

In [142]: x[:,0,0]
Out[142]: array([ 0,  6, 12, 18])

将索引元组解压缩为i,j并在x[:,i,j]中使用它是执行此索引的最简单方法。但是为了将其推广到其他数量的维度,我必须稍微使用元组。 x[i,j]x[(i,j)]相同。

In [147]: for ind in np.ndindex(x.shape[1:]):
     ...:     print(ind,x[(slice(None),)+ind])
     ...:     
((0, 0), array([ 0,  6, 12, 18]))
((0, 1), array([ 1,  7, 13, 19]))
...

enumerate

for ind,val in np.ndenumerate(x[0]):
    print(ind,x[(slice(None),)+ind])

答案 1 :(得分:2)

您可以手动将slice(None)添加到索引元组:

>>> grid.shape
(3, 20, 17, 9)
>>> indices
(19, 16, 8)
>>> grid[:,19,16,8]
array([3059, 6119, 9179])
>>> grid[(slice(None),) + indices]
array([3059, 6119, 9179])

有关详情,请参阅文档中的here

答案 2 :(得分:1)

我相信你要找的是select session_id, count(case when did_tap_on_screen then 1 end) > 0 from events group by session_id

grid[1:][grid[0]]