基本上我正在寻找一种函数或语法,它允许我获得具有任意数量维度的n维numpy数组的最后两个维度的第一个“切片”。
我可以做到这一点,但它太难看了,如果有人发送了一个6d阵列怎么办?必须有像椭圆这样的numpy函数扩展为0,0,0,...而不是:,:,...,...
data_2d = np.ones(5**2).reshape(5,5)
data_3d = np.ones(5**3).reshape(5,5,5)
data_4d = np.ones(5**4).reshape(5,5,5,5)
def get_last2d(data):
if data.ndim == 2:
return data[:]
if data.ndim == 3:
return data[0, :]
if data.ndim == 4:
return data[0, 0, :]
np.array_equal(get_last2d(data_3d), get_last2d(data_4d))
谢谢, 科林
答案 0 :(得分:2)
这个怎么样,
def get_last2d(data):
if data.ndim <= 2:
return data
slc = [0] * (data.ndim - 2)
slc += [slice(None), slice(None)]
return data[slc]
答案 1 :(得分:0)
def get_last_2d(x):
m,n = x.shape[-2:]
return x.flat[:m*n].reshape(m,n)
这是有效的,因为展平数组会按照禁止变化的索引的顺序返回条目,而对于C风格的索引,最后的索引变化最快。因此,扁平阵列的前m * n个条目就是你想要的。