为什么循环节拍索引在这里?

时间:2010-08-19 03:44:01

标签: python performance memory-management numpy

几年前,有人posted活动状态食谱上进行比较,有三个python / NumPy函数;每个都接受相同的参数并返回相同的结果,距离矩阵

其中两项来自公开来源;它们都是 - 或者它们在我看来是 - 惯用的numpy代码。创建距离矩阵所需的重复计算由numpy优雅的索引语法驱动。这是其中之一:

from numpy.matlib import repmat, repeat

def calcDistanceMatrixFastEuclidean(points):
  numPoints = len(points)
  distMat = sqrt(sum((repmat(points, numPoints, 1) - 
             repeat(points, numPoints, axis=0))**2, axis=1))
  return distMat.reshape((numPoints,numPoints))

第三个使用单个循环创建了距离矩阵(考虑到只有1,000个2D点的距离矩阵,有一百万个条目,这显然是很多循环)。乍一看这个函数看起来像我在学习NumPy时编写的代码,我会编写NumPy代码,首先编写Python代码,然后逐行翻译。

活跃状态发布几个月后,在NumPy邮件列表的thread上发布并讨论了比较三者的性能测试结果。

循环函数实际上显着优于另外两个:

from numpy import mat, zeros, newaxis

def calcDistanceMatrixFastEuclidean2(nDimPoints):
  nDimPoints = array(nDimPoints)
  n,m = nDimPoints.shape
  delta = zeros((n,n),'d')
  for d in xrange(m):
    data = nDimPoints[:,d]
    delta += (data - data[:,newaxis])**2
  return sqrt(delta)

线程中的一位参与者(Keir Mierle)提供了为什么这可能是真的原因:

  

我怀疑这会更快的原因是   它有更好的局部性,完全完成计算   在进入下一个工作集之前相对较小的工作集。一个衬里   必须重复将可能较大的MxN阵列拉入处理器。

通过这张海报自己的说法,他的评论只是一种怀疑,而且似乎没有进一步讨论。

关于如何解释这些结果的其他想法?

特别是,是否有一个有用的规则 - 关于何时循环以及何时索引 - 可以从此示例中提取作为编写numpy代码的指导?

对于那些不熟悉NumPy的人,或者没有看过代码的人,这种比较不是基于一个边缘情况 - 如果是这样的话,对我来说肯定不会那么有趣。相反,这种比较涉及在矩阵计算中执行共同任务的函数(即,在给定两个前因的情况下创建结果数组);而且,每个函数又由最常见的numpy内置函数组成。

2 个答案:

答案 0 :(得分:11)

<强> TL; DR 上面的第二个代码只是循环遍历点的维数(3点通过for循环获得3D点),因此循环不多。上面第二个代码中的实际加速是它更好地利用Numpy的功能,以避免在找到点之间的差异时创建一些额外的矩阵。这减少了使用的内存和计算工作量。

更长的解释 我认为calcDistanceMatrixFastEuclidean2函数可能会欺骗你的循环。它只是循环遍历点的维数。对于1D点,循环仅执行一次,对于2D,两次,对于3D,执行三次。这真的没有多少循环。

让我们分析一下代码,看看为什么一个比另一个快。 calcDistanceMatrixFastEuclidean我将致电fast1calcDistanceMatrixFastEuclidean2fast2

fast1基于Matlab的做法,repmap函数证明了这一点。在这种情况下,repmap函数创建一个数组,它只是一遍又一遍地重复的原始数据。但是,如果查看函数的代码,效率非常低。它使用许多Numpy函数(3 reshape s和2 repeat s)来执行此操作。 repeat函数还用于创建一个包含原始数据的数组,每个数据项重复多次。如果我们的输入数据为[1,2,3],那么我们会从[1,2,3,1,2,3,1,2,3]中减去[1,1,1,2,2,2,3,3,3]。 Numpy必须在运行Numpy的C代码之间创建许多额外的矩阵,而这些代码本来是可以避免的。

fast2使用更多Numpy的繁重工作而不会在Numpy调用之间创建尽可能多的矩阵。 fast2遍历点的每个维度,进行减法并保持每个维度之间的平方差异的总计。只有在最后才是平方根。到目前为止,这可能听起来不如fast1那么有效,但fast2通过使用Numpy的索引来避免执行repmat内容。让我们看一下1D的简单情况。 fast2生成数据的一维数组,并从数据的2D(N x 1)数组中减去它。这会在每个点和所有其他点之间创建差异矩阵,而不必使用repmatrepeat,从而绕过创建大量额外数组。这是真正的速度差异在我看来的地方。 fast1在矩阵之间创建了大量额外的东西(并且它们在计算上昂贵地创建)以找到点之间的差异,而fast2更好地利用Numpy的力量来避免这些。

顺便说一句,这里有fast2的更快版本:

def calcDistanceMatrixFastEuclidean3(nDimPoints):
  nDimPoints = array(nDimPoints)
  n,m = nDimPoints.shape
  data = nDimPoints[:,0]
  delta = (data - data[:,newaxis])**2
  for d in xrange(1,m):
    data = nDimPoints[:,d]
    delta += (data - data[:,newaxis])**2
  return sqrt(delta)

不同之处在于我们不再将delta创建为零矩阵。

答案 1 :(得分:1)

dis为了好玩:

<强> dis.dis(calcDistanceMatrixFastEuclidean)

  2           0 LOAD_GLOBAL              0 (len)
              3 LOAD_FAST                0 (points)
              6 CALL_FUNCTION            1
              9 STORE_FAST               1 (numPoints)

  3          12 LOAD_GLOBAL              1 (sqrt)
             15 LOAD_GLOBAL              2 (sum)
             18 LOAD_GLOBAL              3 (repmat)
             21 LOAD_FAST                0 (points)
             24 LOAD_FAST                1 (numPoints)
             27 LOAD_CONST               1 (1)
             30 CALL_FUNCTION            3

  4          33 LOAD_GLOBAL              4 (repeat)
             36 LOAD_FAST                0 (points)
             39 LOAD_FAST                1 (numPoints)
             42 LOAD_CONST               2 ('axis')
             45 LOAD_CONST               3 (0)
             48 CALL_FUNCTION          258
             51 BINARY_SUBTRACT
             52 LOAD_CONST               4 (2)
             55 BINARY_POWER
             56 LOAD_CONST               2 ('axis')
             59 LOAD_CONST               1 (1)
             62 CALL_FUNCTION          257
             65 CALL_FUNCTION            1
             68 STORE_FAST               2 (distMat)

  5          71 LOAD_FAST                2 (distMat)
             74 LOAD_ATTR                5 (reshape)
             77 LOAD_FAST                1 (numPoints)
             80 LOAD_FAST                1 (numPoints)
             83 BUILD_TUPLE              2
             86 CALL_FUNCTION            1
             89 RETURN_VALUE

<强> dis.dis(calcDistanceMatrixFastEuclidean2)

  2           0 LOAD_GLOBAL              0 (array)
              3 LOAD_FAST                0 (nDimPoints)
              6 CALL_FUNCTION            1
              9 STORE_FAST               0 (nDimPoints)

  3          12 LOAD_FAST                0 (nDimPoints)
             15 LOAD_ATTR                1 (shape)
             18 UNPACK_SEQUENCE          2
             21 STORE_FAST               1 (n)
             24 STORE_FAST               2 (m)

  4          27 LOAD_GLOBAL              2 (zeros)
             30 LOAD_FAST                1 (n)
             33 LOAD_FAST                1 (n)
             36 BUILD_TUPLE              2
             39 LOAD_CONST               1 ('d')
             42 CALL_FUNCTION            2
             45 STORE_FAST               3 (delta)

  5          48 SETUP_LOOP              76 (to 127)
             51 LOAD_GLOBAL              3 (xrange)
             54 LOAD_FAST                2 (m)
             57 CALL_FUNCTION            1
             60 GET_ITER
        >>   61 FOR_ITER                62 (to 126)
             64 STORE_FAST               4 (d)

  6          67 LOAD_FAST                0 (nDimPoints)
             70 LOAD_CONST               0 (None)
             73 LOAD_CONST               0 (None)
             76 BUILD_SLICE              2
             79 LOAD_FAST                4 (d)
             82 BUILD_TUPLE              2
             85 BINARY_SUBSCR
             86 STORE_FAST               5 (data)

  7          89 LOAD_FAST                3 (delta)
             92 LOAD_FAST                5 (data)
             95 LOAD_FAST                5 (data)
             98 LOAD_CONST               0 (None)
            101 LOAD_CONST               0 (None)
            104 BUILD_SLICE              2
            107 LOAD_GLOBAL              4 (newaxis)
            110 BUILD_TUPLE              2
            113 BINARY_SUBSCR
            114 BINARY_SUBTRACT
            115 LOAD_CONST               2 (2)
            118 BINARY_POWER
            119 INPLACE_ADD
            120 STORE_FAST               3 (delta)
            123 JUMP_ABSOLUTE           61
        >>  126 POP_BLOCK

  8     >>  127 LOAD_GLOBAL              5 (sqrt)
            130 LOAD_FAST                3 (delta)
            133 CALL_FUNCTION            1
            136 RETURN_VALUE

我不是dis的专家,但似乎你必须更多地关注第一个调用的函数,以了解它们为什么需要一段时间。还有一个带有Python的性能分析器工具,cProfile