我有一个函数,该函数处理维度为(h,w,200)的输入数组(数字200可以变化)并返回维度为(h,w,50,3)的数组。对于大小为512,512,200的输入数组,此功能大约需要0.8秒。
def myfunc(arr, n = 50):
#shape of arr is (h,w,200)
#output shape is (h,w,50,3)
#a1 is an array of length 50, I get them from a different
#function, which doesn't take much time. For simplicity, I fix it
#as np.arange(0,50)
a1 = np.arange(0,50)
output = np.stack((arr[:,:,a1],)*3, axis = -1)
return output
此预处理步骤在单个批处理中对约8个数组执行,因此加载一批数据需要8 * 0.8 = 6.4秒。有没有办法加快myfunc的计算?我可以使用像numba这样的库吗?
答案 0 :(得分:2)
我大约在同一时间得到
In [14]: arr = np.ones((512,512,200))
In [15]: timeit output = np.stack((arr[:,:,np.arange(50)],)*3, axis=-1)
681 ms ± 5.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [16]: np.stack((arr[:,:,np.arange(50)],)*3, axis=-1).shape
Out[16]: (512, 512, 50, 3)
详细查看时间。
首先执行索引/复制步骤,大约需要1/3的时间:
In [17]: timeit arr[:,:,np.arange(50)]
249 ms ± 306 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
还有stack
:
In [18]: %%timeit temp = arr[:,:,np.arange(50)]
...: output = np.stack([temp,temp,temp], axis=-1)
...:
...:
426 ms ± 367 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
stack
扩展尺寸,然后连接;让我们直接调用串联:
In [19]: %%timeit temp = arr[:,:,np.arange(50),None]
...: output = np.concatenate([temp,temp,temp], axis=-1)
...:
...:
430 ms ± 8.36 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
另一种方法是使用repeat
:
In [20]: %%timeit temp = arr[:,:,np.arange(50),None]
...: output = np.repeat(temp, 3, axis=-1)
...:
...:
531 ms ± 155 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
因此,您的代码看起来就足够好了。
索引和连接已经使用了编译后的代码,因此我不希望numba
有所帮助(并不是我有很多经验)。
在新的前轴上堆叠更快(制造(3,512,512,50))
In [21]: %%timeit temp = arr[:,:,np.arange(50)]
...: output = np.stack([temp,temp,temp])
...:
...:
254 ms ± 1.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
尽管随后的操作可能会更慢(如果它们需要复制和/或重新排序),则可以(便宜地)进行转置。全部copy
阵列时间中的普通output
大约在350毫秒左右。
受评论启发,我尝试播放广播作业:
In [101]: %%timeit temp = arr[:,:,np.arange(50)]
...: res = np.empty(temp.shape + (3,), temp.dtype)
...: res[...] = temp[...,None]
...:
...:
...:
337 ms ± 1.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
同一个棒球场。
另一个技巧是使用strides
来制作“虚拟”副本:
In [74]: res1 = np.broadcast_to(arr, (3,)+arr.shape)
In [75]: res1.shape
Out[75]: (3, 512, 512, 200)
In [76]: res1.strides
Out[76]: (0, 819200, 1600, 8)
由于某些原因,这不适用于(512,512,200,3)
。它可能与broadcast_to
实现有关。也许有人可以尝试as_strided
。
尽管我可以很好地转置它:
np.broadcast_to(arr, (3,)+arr.shape).transpose(1,2,3,0)
无论如何,这要快很多:
In [82]: timeit res1 = np.broadcast_to(arr, (3,)+arr.shape)
10.4 µs ± 188 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
(但是制作一个copy
可以节省时间。)