几年前,有人posted在活动状态食谱上进行比较,有三个python / NumPy函数;每个都接受相同的参数并返回相同的结果,距离矩阵。
其中两项来自公开来源;它们都是 - 或者它们在我看来是 - 惯用的numpy代码。创建距离矩阵所需的重复计算由numpy优雅的索引语法驱动。这是其中之一:
from numpy.matlib import repmat, repeat
def calcDistanceMatrixFastEuclidean(points):
numPoints = len(points)
distMat = sqrt(sum((repmat(points, numPoints, 1) -
repeat(points, numPoints, axis=0))**2, axis=1))
return distMat.reshape((numPoints,numPoints))
第三个使用单个循环创建了距离矩阵(考虑到只有1,000个2D点的距离矩阵,有一百万个条目,这显然是很多循环)。乍一看这个函数看起来像我在学习NumPy时编写的代码,我会编写NumPy代码,首先编写Python代码,然后逐行翻译。
活跃状态发布几个月后,在NumPy邮件列表的thread上发布并讨论了比较三者的性能测试结果。
循环函数实际上显着优于另外两个:
from numpy import mat, zeros, newaxis
def calcDistanceMatrixFastEuclidean2(nDimPoints):
nDimPoints = array(nDimPoints)
n,m = nDimPoints.shape
delta = zeros((n,n),'d')
for d in xrange(m):
data = nDimPoints[:,d]
delta += (data - data[:,newaxis])**2
return sqrt(delta)
线程中的一位参与者(Keir Mierle)提供了为什么这可能是真的原因:
我怀疑这会更快的原因是 它有更好的局部性,完全完成计算 在进入下一个工作集之前相对较小的工作集。一个衬里 必须重复将可能较大的MxN阵列拉入处理器。
通过这张海报自己的说法,他的评论只是一种怀疑,而且似乎没有进一步讨论。
关于如何解释这些结果的其他想法?
特别是,是否有一个有用的规则 - 关于何时循环以及何时索引 - 可以从此示例中提取作为编写numpy代码的指导?
对于那些不熟悉NumPy的人,或者没有看过代码的人,这种比较不是基于一个边缘情况 - 如果是这样的话,对我来说肯定不会那么有趣。相反,这种比较涉及在矩阵计算中执行共同任务的函数(即,在给定两个前因的情况下创建结果数组);而且,每个函数又由最常见的numpy内置函数组成。
答案 0 :(得分:11)
<强> TL; DR 上面的第二个代码只是循环遍历点的维数(3点通过for循环获得3D点),因此循环不多。上面第二个代码中的实际加速是它更好地利用Numpy的功能,以避免在找到点之间的差异时创建一些额外的矩阵。这减少了使用的内存和计算工作量。
更长的解释
我认为calcDistanceMatrixFastEuclidean2
函数可能会欺骗你的循环。它只是循环遍历点的维数。对于1D点,循环仅执行一次,对于2D,两次,对于3D,执行三次。这真的没有多少循环。
让我们分析一下代码,看看为什么一个比另一个快。 calcDistanceMatrixFastEuclidean
我将致电fast1
,calcDistanceMatrixFastEuclidean2
将fast2
。
fast1
基于Matlab的做法,repmap
函数证明了这一点。在这种情况下,repmap
函数创建一个数组,它只是一遍又一遍地重复的原始数据。但是,如果查看函数的代码,效率非常低。它使用许多Numpy函数(3 reshape
s和2 repeat
s)来执行此操作。 repeat
函数还用于创建一个包含原始数据的数组,每个数据项重复多次。如果我们的输入数据为[1,2,3]
,那么我们会从[1,2,3,1,2,3,1,2,3]
中减去[1,1,1,2,2,2,3,3,3]
。 Numpy必须在运行Numpy的C代码之间创建许多额外的矩阵,而这些代码本来是可以避免的。
fast2
使用更多Numpy的繁重工作而不会在Numpy调用之间创建尽可能多的矩阵。 fast2
遍历点的每个维度,进行减法并保持每个维度之间的平方差异的总计。只有在最后才是平方根。到目前为止,这可能听起来不如fast1
那么有效,但fast2
通过使用Numpy的索引来避免执行repmat
内容。让我们看一下1D的简单情况。 fast2
生成数据的一维数组,并从数据的2D(N x 1)数组中减去它。这会在每个点和所有其他点之间创建差异矩阵,而不必使用repmat
和repeat
,从而绕过创建大量额外数组。这是真正的速度差异在我看来的地方。 fast1
在矩阵之间创建了大量额外的东西(并且它们在计算上昂贵地创建)以找到点之间的差异,而fast2
更好地利用Numpy的力量来避免这些。
顺便说一句,这里有fast2
的更快版本:
def calcDistanceMatrixFastEuclidean3(nDimPoints):
nDimPoints = array(nDimPoints)
n,m = nDimPoints.shape
data = nDimPoints[:,0]
delta = (data - data[:,newaxis])**2
for d in xrange(1,m):
data = nDimPoints[:,d]
delta += (data - data[:,newaxis])**2
return sqrt(delta)
不同之处在于我们不再将delta创建为零矩阵。
答案 1 :(得分:1)
dis
为了好玩:
<强> dis.dis(calcDistanceMatrixFastEuclidean)
强>
2 0 LOAD_GLOBAL 0 (len)
3 LOAD_FAST 0 (points)
6 CALL_FUNCTION 1
9 STORE_FAST 1 (numPoints)
3 12 LOAD_GLOBAL 1 (sqrt)
15 LOAD_GLOBAL 2 (sum)
18 LOAD_GLOBAL 3 (repmat)
21 LOAD_FAST 0 (points)
24 LOAD_FAST 1 (numPoints)
27 LOAD_CONST 1 (1)
30 CALL_FUNCTION 3
4 33 LOAD_GLOBAL 4 (repeat)
36 LOAD_FAST 0 (points)
39 LOAD_FAST 1 (numPoints)
42 LOAD_CONST 2 ('axis')
45 LOAD_CONST 3 (0)
48 CALL_FUNCTION 258
51 BINARY_SUBTRACT
52 LOAD_CONST 4 (2)
55 BINARY_POWER
56 LOAD_CONST 2 ('axis')
59 LOAD_CONST 1 (1)
62 CALL_FUNCTION 257
65 CALL_FUNCTION 1
68 STORE_FAST 2 (distMat)
5 71 LOAD_FAST 2 (distMat)
74 LOAD_ATTR 5 (reshape)
77 LOAD_FAST 1 (numPoints)
80 LOAD_FAST 1 (numPoints)
83 BUILD_TUPLE 2
86 CALL_FUNCTION 1
89 RETURN_VALUE
<强> dis.dis(calcDistanceMatrixFastEuclidean2)
强>
2 0 LOAD_GLOBAL 0 (array)
3 LOAD_FAST 0 (nDimPoints)
6 CALL_FUNCTION 1
9 STORE_FAST 0 (nDimPoints)
3 12 LOAD_FAST 0 (nDimPoints)
15 LOAD_ATTR 1 (shape)
18 UNPACK_SEQUENCE 2
21 STORE_FAST 1 (n)
24 STORE_FAST 2 (m)
4 27 LOAD_GLOBAL 2 (zeros)
30 LOAD_FAST 1 (n)
33 LOAD_FAST 1 (n)
36 BUILD_TUPLE 2
39 LOAD_CONST 1 ('d')
42 CALL_FUNCTION 2
45 STORE_FAST 3 (delta)
5 48 SETUP_LOOP 76 (to 127)
51 LOAD_GLOBAL 3 (xrange)
54 LOAD_FAST 2 (m)
57 CALL_FUNCTION 1
60 GET_ITER
>> 61 FOR_ITER 62 (to 126)
64 STORE_FAST 4 (d)
6 67 LOAD_FAST 0 (nDimPoints)
70 LOAD_CONST 0 (None)
73 LOAD_CONST 0 (None)
76 BUILD_SLICE 2
79 LOAD_FAST 4 (d)
82 BUILD_TUPLE 2
85 BINARY_SUBSCR
86 STORE_FAST 5 (data)
7 89 LOAD_FAST 3 (delta)
92 LOAD_FAST 5 (data)
95 LOAD_FAST 5 (data)
98 LOAD_CONST 0 (None)
101 LOAD_CONST 0 (None)
104 BUILD_SLICE 2
107 LOAD_GLOBAL 4 (newaxis)
110 BUILD_TUPLE 2
113 BINARY_SUBSCR
114 BINARY_SUBTRACT
115 LOAD_CONST 2 (2)
118 BINARY_POWER
119 INPLACE_ADD
120 STORE_FAST 3 (delta)
123 JUMP_ABSOLUTE 61
>> 126 POP_BLOCK
8 >> 127 LOAD_GLOBAL 5 (sqrt)
130 LOAD_FAST 3 (delta)
133 CALL_FUNCTION 1
136 RETURN_VALUE
我不是dis
的专家,但似乎你必须更多地关注第一个调用的函数,以了解它们为什么需要一段时间。还有一个带有Python的性能分析器工具,cProfile
。