快速将3D numpy数组的片段组成3D子数组

时间:2018-09-30 17:15:53

标签: python arrays numpy indexing

我有一个3D numpy数组x。我想在轴0上获取每个切片的子集(每个子集的形状相同,但每个切片的起始索引和结束索引可能不同),然后将它们组合成一个单独的3D numpy数组。我可以用

来实现
import numpy as np

x = np.arange(24).reshape((3, 4, 2))
starts = [0, 2, 1]
ends = [2, 4, 3]

np.stack([x[i, starts[i]:ends[i]] for i in range(3)])

但是1)是否有任何方法可以通过花式索引在单个操作中执行此操作,并且2)这样可以加快处理速度?

1 个答案:

答案 0 :(得分:1)

我们可以利用基于np.lib.stride_tricks.as_stridedscikit-image's view_as_windows来获取滑动窗口。 More info on use of as_strided based view_as_windows

from skimage.util.shape import view_as_windows

L = 2 # ends[0]-starts[0]
w = view_as_windows(x,(1,L,1))[...,0,:,0]
out = w[np.arange(len(starts)), starts].swapaxes(1,2)

或者,利用broadcasting的紧凑版本将生成-

x[np.arange(len(starts))[:,None],np.asarray(starts)[:,None] + np.arange(L)]