numpy.dot()函数缓慢的原因以及如果使用自定义类如何缓解它们?

时间:2017-09-02 20:42:31

标签: python performance numpy matrix

我正在分析一个numpy dot产品电话。

numpy.dot(pseudo,pseudo)

pseudo是一个自定义对象的数组。定义为:

pseudo = numpy.array(
         [[PseudoBinary(1), PseudoBinary(0), PseudoBinary(1)],
          [PseudoBinary(1), PseudoBinary(0), PseudoBinary(0)],
          [PseudoBinary(1), PseudoBinary(0), PseudoBinary(1)]])

PseudoBinary是一个具有自定义乘法函数的类。它是OR而不是乘法。有关PseudoBinary定义的完整代码,请参见下文。

类型:

(Pdb) pseudo.dtype
dtype('O')

根据我的分析,伪点积比使用具有整数值的矩阵的点积慢约500倍。下面给出了分析代码的指针。

我对缓慢的原因以及是否有办法减轻它们感兴趣。

缓慢的一些原因可能是:

  • 伪的内存布局不会使用连续的内存。根据{{​​3}},numpy使用带有对象类型的指针。在矩阵乘法期间,可能发生一堆指针解引用,而不是直接从连续的存储器读取。

  • Numpy乘法可能不会使用优化的内部编译实现。 (BLAS,ATLAS等)根据this,应该采用各种条件来回退到优化的实施。使用自定义对象可能会破坏它们

还有其他因素在起作用吗?有任何改进建议吗?

所有这一切的出发点是this问题。在那里,OP正在寻找“定制点产品”。访问类似于点积运算的两个矩阵的元素的操作,但除了将列和行的相应元素相乘之外还执行其他操作。在this中,我推荐了一个覆盖__mul__函数的自定义对象。但是这种方法的numpy.dot性能非常慢。执行性能测量的代码也可以在该答案中看到。

显示PseudoBinary类和点积执行的代码。

#!/usr/bin/env python


 from __future__ import absolute_import
 from __future__ import print_function
 import numpy

 class PseudoBinary(object):
     def __init__(self,i):
         self.i = i

     def __mul__(self,rhs):
         return PseudoBinary(self.i or rhs.i)

     __rmul__ = __mul__
     __imul__ = __mul__

     def __add__(self,rhs):
         return PseudoBinary(self.i + rhs.i)

     __radd__ = __add__
     __iadd__ = __add__

     def __str__(self):
         return "P"+str(self.i)

     __repr__ = __str__

 base = numpy.array(
       [[1, 0, 1],
        [1, 0, 0],
        [1, 0, 1]])

 pseudo = numpy.array(
          [[PseudoBinary(1), PseudoBinary(0), PseudoBinary(1)],
           [PseudoBinary(1), PseudoBinary(0), PseudoBinary(0)],
           [PseudoBinary(1), PseudoBinary(0), PseudoBinary(1)]])

 baseRes = numpy.dot(base,base)
 pseudoRes = numpy.dot(pseudo,pseudo)

 print("baseRes\n",baseRes)
 print("pseudoRes\n",pseudoRes)

打印:

baseRes
 [[2 0 2]
 [1 0 1]
 [2 0 2]]
pseudoRes
 [[P3 P2 P2]
 [P3 P1 P2]
 [P3 P2 P2]]

2 个答案:

答案 0 :(得分:2)

你使用对象数组的任何任何都会很慢。 NumPy通常快速应用于对象数组的原因都没有。

  • 对象数组不能连续存储其元素。他们必须存储和取消引用指针。
    • 他们不知道他们需要为他们的元素分配多少空间。
    • 他们的元素可能并非都是相同的大小。
    • 您插入到对象数组中的元素已经在数组外部分配,并且无法复制它们。
  • 对象数组必须对所有元素操作执行动态分派。每次他们添加或增加两个元素时,他们必须弄清楚如何再做一遍。
  • 对象数组无法加速其元素的实现,例如您的缓慢解释__add____mul__
  • 对象数组无法避免与其元素操作关联的内存分配,例如在每个元素PseudoBinary或{上为该对象分配新的__dict__对象和新的__add__ {1}}。
  • 对象数组无法并行化操作,因为其元素上的所有操作都需要保留GIL。
  • 对象数组不能使用LAPACK或BLAS,因为没有任何Python数据类型的LAPACK或BLAS函数。

基本上,在没有NumPy的情况下进行Python数学运算的每个理由都适用于使用对象数组做任何事情。

至于如何改善表现?不要使用对象数组。使用常规数组,并根据NumPy提供的操作找到实现所需事物的方法,或者明确写出循环并使用Numba或Cython之类的东西来编译代码。

答案 1 :(得分:1)

dot是外部产品,后跟一个轴上的和。

对于pseudodot略高于产品等值的总和:

In [18]: timeit (pseudo[:,:,None]*pseudo[None,:,:]).sum(axis=1)
75.7 µs ± 3.14 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [19]: timeit  np.dot(pseudo, pseudo)
63.9 µs ± 1.91 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

对于basedot明显快于同等数量。

In [20]: timeit (base[:,:,None]*base[None,:,:]).sum(axis=1)
13.9 µs ± 24.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [21]: timeit  np.dot(base,base)
1.58 µs ± 53.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

因此,对于数值数组,dot可以将整个任务传递给优化的编译代码(BLAS或其他)。

通过创建数值对象数组并比较简单元素产品,我们可以进一步了解对象dtype如何影响速度:

In [28]: baso = base.astype(object)
In [29]: timeit base*base
766 ns ± 48.1 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [30]: timeit baso*baso
2.45 µs ± 73.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [31]: timeit pseudo*pseudo
13.7 µs ± 41.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

使用'或' (|)代替*,我们可以计算与pseudo相同的内容,但使用base

In [34]: (base[:,:,None] | base[None,:,:]).sum(axis=1)
Out[34]: 
array([[3, 2, 2],
       [3, 1, 2],
       [3, 2, 2]], dtype=int32)
In [35]: timeit (base[:,:,None] | base[None,:,:]).sum(axis=1)
15.1 µs ± 492 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)