numpy take不能使用切片索引

时间:2015-02-23 22:32:18

标签: python arrays numpy slice

根据take的numpy文档 它与“花式”索引(使用数组索引数组)的功能相同。但是,如果您需要沿给定轴的元素,则可以更容易使用。

然而,不像"花式"或使用切片作为索引的常规numpy索引似乎不受支持:

In [319]: A = np.arange(20).reshape(4, 5)

In [320]: A[..., 1:4]
Out[320]: 
array([[ 1,  2,  3],
       [ 6,  7,  8],
       [11, 12, 13],
       [16, 17, 18]])

In [321]: np.take(A, slice(1, 4), axis=-1)
TypeError: long() argument must be a string or a number, not 'slice'

使用仅在运行时已知的轴上的切片索引数组的最佳方法是什么?

3 个答案:

答案 0 :(得分:2)

我认为你的意思是:

In [566]: np.take(A, slice(1,4))
...
TypeError: int() argument must be a string or a number, not 'slice'

但是

np.take(A, np.r_[1:4])

就像A[1:4]

一样

np.insertnp.apply_along_axis这样的函数通过构造可能包含标量,切片和数组的索引元组来实现通用性。

ind = tuple([slice(1,4)])  # ndim terms to match array
A[ind]

np.tensordot是使用np.transpose将动作轴移动到最后的示例(供np.dot使用)。

另一个伎俩就是让所有剩余的盈余崩溃。轴重构为一个重塑形状。然后重新塑造。

答案 1 :(得分:2)

  

根据numpy docs for take,它与“花式”索引(索引数组使用数组)做同样的事情。

np.take的第二个参数必须是类似数组(数组,列表,元组等),而不是slice对象。您可以构造一个索引数组或列表来执行所需的切片:

a = np.arange(24).reshape(2, 3, 4)

np.take(a, slice(1, 4, 2), 2)
# TypeError: long() argument must be a string or a number, not 'slice'

np.take(a, range(1, 4, 2), 2)
# array([[[ 1,  3],
#         [ 5,  7],
#         [ 9, 11]],

#        [[13, 15],
#         [17, 19],
#         [21, 23]]])
  

使用仅在运行时已知的轴上的切片索引数组的最佳方法是什么?

我经常喜欢做的是使用np.rollaxis将轴索引到第一个索引,进行索引,然后将其回滚到原始位置。

例如,假设我想要沿着第3轴的3D数组的奇数切片:

sliced1 = a[:, :, 1::2]

如果我想在运行时指定要切片的轴,我可以这样做:

n = 2    # axis to slice along

sliced2 = np.rollaxis(np.rollaxis(a, n, 0)[1::2], 0, n + 1)

assert np.all(sliced1 == sliced2)

稍微解开那个单线:

# roll the nth axis to the 0th position
np.rollaxis(a, n, 0)

# index odd-numbered slices along the 0th axis
np.rollaxis(a, n, 0)[1::2]

# roll the 0th axis back so that it lies before position n + 1 (note the '+ 1'!)
np.rollaxis(np.rollaxis(a, n, 0)[1::2], 0, n + 1)

答案 2 :(得分:1)

最有效的方法似乎是A[(slice(None),) * axis + (slice(1, 4),)]

In [19]: import numpy as np
    ...: x = np.random.normal(0, 1, (50, 50, 50))
    ...: s = slice(10, 20)
    ...: axis = 2
    ...: 
    ...: 

In [20]: timeit np.rollaxis(np.rollaxis(x, axis, 0)[s], 0, axis + 1)
2.32 µs ± 15.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

In [21]: timeit x.take(np.arange(x.shape[axis])[s], axis)
28.5 µs ± 38.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [22]: timeit x[(slice(None),) * axis + (s,)]
321 ns ± 0.341 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)