我正在努力解决矢量化问题,乍一看似乎很简单:
假设我有100张大小为(7,7)的图像,其中2个通道由numpy
大小数组(100,2,7,7)表示。我想在所有图像上提取小补丁(比如说大小(2,3,3)),但这些补丁并不是位于每个图像的同一个地方。贴片的位置由尺寸(2,100)的矩阵描述(每个图像一个x和一个y)。
我能够在所有图像上使用for循环,但这需要时间。
以下是一个示例代码:
data = np.arange(9800).reshape((100, 2, 7, 7))
size = 3
pos = np.random.randint(0, 7-size, (2, 100))
for i in range(a.shape[0]):
patch = data[i, :, pos[0,i]:(pos[0,i]+size), pos[1,i]:pos[1,i]+size]
换句话说,我想在没有for循环的情况下重现这段代码。 有没有人有这方面的线索?
答案 0 :(得分:1)
numpy.lib.stride_tricks.as_strided
的标准技巧应该适用于此:
das = np.lib.stride_tricks.as_strided(data, (5, 5, 100, 2, 3, 3), data.strides[-2:] + data.strides)
patches = das[(*pos, np.arange(100))]
您可以使用以下方式验证其有效:
for i in range(data.shape[0]):
assert np.all(patches[i]==data[i,:,pos[0,i]:pos[0,i]+size,pos[1,i]:pos[1,i]+size])