我需要通过位移的3D矢量移动3D数组用于算法。 截至目前,我正在使用这种(非常难看)的方法:
shiftedArray = np.roll(np.roll(np.roll(arrayToShift, shift[0], axis=0)
, shift[1], axis=1),
shift[2], axis=2)
哪个有效,但意味着我打3个卷! (根据我的分析,我算法时间的58%用于这些)
来自Numpy.roll的文档:
参数:
shift:intaxis:int,optional
在参数中没有提到类似数组...所以我不能进行多维滚动?
我以为我可以调用这种功能(听起来像Numpy的事情):
np.roll(arrayToShift,3DshiftVector,axis=(0,1,2))
也许我的阵列的扁平化版本被重塑了?但那我该如何计算换档向量?这种转变真的一样吗?
我很惊讶地发现没有简单的解决办法,因为我认为这是一件很常见的事情(好吧,不是 常见,但......)
那么我们如何 - 相对地 - 通过N维向量有效地移动ndarray?
注意:这个问题在2015年被问到,当numpy的滚动方法不支持此功能时。
答案 0 :(得分:5)
理论上,使用@Ed Smith所描述的scipy.ndimage.interpolation.shift
应该有效,但是由于一个开放的bug(https://github.com/scipy/scipy/issues/1323),它不会给出相当于{的多次调用的结果{ {1}}。
更新:numpy版本1.12.0中的“{3}}添加了”多卷“功能。这是一个二维示例,其中第一个轴滚动一个位置,第二个轴滚动三个位置:
np.roll
这会使下面的代码过时。我会把它留给后人。
下面的代码定义了一个我调用In [7]: x = np.arange(20).reshape(4,5)
In [8]: x
Out[8]:
array([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]])
In [9]: numpy.roll(x, [1, 3], axis=(0, 1))
Out[9]:
array([[17, 18, 19, 15, 16],
[ 2, 3, 4, 0, 1],
[ 7, 8, 9, 5, 6],
[12, 13, 14, 10, 11]])
的函数,它可以完成你想要的任务。下面是一个将其应用于形状为(500,500,500)的数组的示例:
multiroll
使用对In [64]: x = np.random.randn(500, 500, 500)
In [65]: shift = [10, 15, 20]
的多次调用来生成预期结果:
np.roll
使用In [66]: yroll3 = np.roll(np.roll(np.roll(x, shift[0], axis=0), shift[1], axis=1), shift[2], axis=2)
:
multiroll
确认我们得到了预期的结果:
In [67]: ymulti = multiroll(x, shift)
对于此大小的数组,对In [68]: np.all(yroll3 == ymulti)
Out[68]: True
进行三次调用几乎比调用np.roll
慢三倍:
multiroll
以下是In [69]: %timeit yroll3 = np.roll(np.roll(np.roll(x, shift[0], axis=0), shift[1], axis=1), shift[2], axis=2)
1 loops, best of 3: 1.34 s per loop
In [70]: %timeit ymulti = multiroll(x, shift)
1 loops, best of 3: 474 ms per loop
:
multiroll
答案 1 :(得分:3)
我认为scipy.ndimage.interpolation.shift
会根据docs
shift:float或sequence,optional
沿轴的移动。如果浮动,每个轴的移位相同。如果序列,shift应包含每个轴的一个值。
这意味着您可以执行以下操作,
from scipy.ndimage.interpolation import shift
import numpy as np
arrayToShift = np.reshape([i for i in range(27)],(3,3,3))
print('Before shift')
print(arrayToShift)
shiftVector = (1,2,3)
shiftedarray = shift(arrayToShift,shift=shiftVector,mode='wrap')
print('After shift')
print(shiftedarray)
哪个收益率,
Before shift
[[[ 0 1 2]
[ 3 4 5]
[ 6 7 8]]
[[ 9 10 11]
[12 13 14]
[15 16 17]]
[[18 19 20]
[21 22 23]
[24 25 26]]]
After shift
[[[16 17 16]
[13 14 13]
[10 11 10]]
[[ 7 8 7]
[ 4 5 4]
[ 1 2 1]]
[[16 17 16]
[13 14 13]
[10 11 10]]]
答案 2 :(得分:0)
我相信roll
很慢,因为滚动数组不能表示为原始数据的视图,因为切片或重塑操作可以。因此每次都会复制数据。有关背景信息,请参阅:https://scipy-lectures.github.io/advanced/advanced_numpy/#life-of-ndarray
首先填充数组可能值得尝试(使用' wrap'模式),然后在填充数组上使用切片来获取shiftedArray
:http://docs.scipy.org/doc/numpy/reference/generated/numpy.pad.html
答案 3 :(得分:0)
wrap
模式下的{p> take
,我认为它不会更改内存中的数组。
以下是使用@ EdSmith输入的实现:
arrayToShift = np.reshape([i for i in range(27)],(3,3,3))
shiftVector = np.array((1,2,3))
ind = 3-shiftVector
np.take(np.take(np.take(arrayToShift,range(ind[0],ind[0]+3),axis=0,mode='wrap'),range(ind[1],ind[1]+3),axis=1,mode='wrap'),range(ind[2],ind[2]+3),axis=2,mode='wrap')
与OP的相同:
np.roll(np.roll(np.roll(arrayToShift, shift[0], axis=0) , shift[1], axis=1),shift[2], axis=2)
给出:
array([[[21, 22, 23],
[24, 25, 26],
[18, 19, 20]],
[[ 3, 4, 5],
[ 6, 7, 8],
[ 0, 1, 2]],
[[12, 13, 14],
[15, 16, 17],
[ 9, 10, 11]]])
答案 4 :(得分:0)
np.roll有多个维度。只是做
np.roll(arrayToShift, (shift[0], shift[1], shift[2]), axis=(0,1,2))
它不是很聪明,所以你必须指定轴