numpy切片返回最后两个维度

时间:2014-11-24 17:52:08

标签: python numpy slice

基本上我正在寻找一种函数或语法,它允许我获得具有任意数量维度的n维numpy数组的最后两个维度的第一个“切片”。

我可以做到这一点,但它太难看了,如果有人发送了一个6d阵列怎么办?必须有像椭圆这样的numpy函数扩展为0,0,0,...而不是:,:,...,...

data_2d = np.ones(5**2).reshape(5,5)
data_3d = np.ones(5**3).reshape(5,5,5)
data_4d = np.ones(5**4).reshape(5,5,5,5)

def get_last2d(data):
    if data.ndim == 2:
        return data[:]
    if data.ndim == 3:
        return data[0, :]
    if data.ndim == 4:
        return data[0, 0, :]

np.array_equal(get_last2d(data_3d), get_last2d(data_4d))

谢谢, 科林

2 个答案:

答案 0 :(得分:2)

这个怎么样,

def get_last2d(data):
    if data.ndim <= 2:
        return data
    slc = [0] * (data.ndim - 2)
    slc += [slice(None), slice(None)]
    return data[slc]

答案 1 :(得分:0)

def get_last_2d(x):
    m,n = x.shape[-2:]
    return x.flat[:m*n].reshape(m,n)

这是有效的,因为展平数组会按照禁止变化的索引的顺序返回条目,而对于C风格的索引,最后的索引变化最快。因此,扁平阵列的前m * n个条目就是你想要的。