访问python中的第n个维度

时间:2017-08-02 13:40:08

标签: python arrays numpy multidimensional-array indexing

我希望能够轻松读取多维numpy数组的某些部分。对于任何访问第一维的数组都很容易(b[index])。另一方面,访问第六维是“硬”(特别是阅读)。

b[:,:,:,:,:,index] #the next person to read the code will have to count the :

有更好的方法吗? 特别是有一种方法,在编写程序时不知道轴吗?

编辑: 索引维度不一定是最后一个维度

5 个答案:

答案 0 :(得分:5)

您可以使用np.take。 例如:

b.take(index, axis=5)

答案 1 :(得分:5)

如果您想要一个视图并希望它快速,您可以手动创建索引:

arr[(slice(None), )*5 + (your_index, )]
#                   ^---- This is equivalent to 5 colons: `:, :, :, :, :`

这比np.take快得多,并且只比使用:索引的速度慢一点:

import numpy as np

arr = np.random.random((10, 10, 10, 10, 10, 10, 10))

np.testing.assert_array_equal(arr[:,:,:,:,:,4], arr.take(4, axis=5))
np.testing.assert_array_equal(arr[:,:,:,:,:,4], arr[(slice(None), )*5 + (4, )])
%timeit arr.take(4, axis=5)
# 18.6 ms ± 249 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit arr[(slice(None), )*5 + (4, )]
# 2.72 µs ± 39.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit arr[:, :, :, :, :, 4]
# 2.29 µs ± 107 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

但也许不那么可读,所以如果你经常需要它,你可能应该把它放在一个有意义的名字的函数中:

def index_axis(arr, index, axis):
    return arr[(slice(None), )*axis + (index, )]

np.testing.assert_array_equal(arr[:,:,:,:,:,4], index_axis(arr, 4, axis=5))

%timeit index_axis(arr, 4, axis=5)
# 3.79 µs ± 127 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

答案 2 :(得分:3)

MSeifert和kazemakase的答案之间的中间方式(可读性和时间)正在使用np.rollaxis

np.rollaxis(b, axis=5)[index]

测试解决方案:

import numpy as np

arr = np.random.random((10, 10, 10, 10, 10, 10, 10))

np.testing.assert_array_equal(arr[:,:,:,:,:,4], arr.take(4, axis=5))
np.testing.assert_array_equal(arr[:,:,:,:,:,4], arr[(slice(None), )*5 + (4, )])
np.testing.assert_array_equal(arr[:,:,:,:,:,4], np.rollaxis(arr, 5)[4])

%timeit arr.take(4, axis=5)
# 100 loops, best of 3: 4.44 ms per loop
%timeit arr[(slice(None), )*5 + (4, )]
# 1000000 loops, best of 3: 731 ns per loop
%timeit arr[:, :, :, :, :, 4]
# 1000000 loops, best of 3: 540 ns per loop
%timeit np.rollaxis(arr, 5)[4]
# 100000 loops, best of 3: 3.41 µs per loop

答案 3 :(得分:1)

本着@JürgMerlinSpaak rollaxis的精神,但速度更快而不是deprecated

b.swapaxes(0, axis)[index]

答案 4 :(得分:-1)

你可以说:

slice = b[..., index]