cython做了什么numpy优化?

时间:2018-01-04 20:47:53

标签: python numpy cython

我有点惊讶地发现:

# fast_ops_c.pyx
cimport cython
cimport numpy as np

@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False)  # turn off negative index wrapping for entire function
@cython.nonecheck(False)
def c_iseq_f1(np.ndarray[np.double_t, ndim=1, cast=False] x, double val):
    # Test (x==val) except gives NaN where x is NaN
    cdef np.ndarray[np.double_t, ndim=1] result = np.empty_like(x)
    cdef size_t i = 0
    cdef double _x = 0
    for i in range(len(x)):
        _x = x[i]
        result[i] = (_x-_x) + (_x==val)
    return result

比订单更快或更快:

@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False)  # turn off negative index wrapping for entire function
@cython.nonecheck(False)
def c_iseq_f2(np.ndarray[np.double_t, ndim=1, cast=False] x, double val):
    cdef np.ndarray[np.double_t, ndim=1] result = np.empty_like(x)
    cdef size_t i = 0
    cdef double _x = 0
    for _x in x:        # Iterate over elements
        result[i] = (_x-_x) + (_x==val)
    return result

(对于大型数组)。我正在使用以下内容来测试性能:

# fast_ops.py
try:
    import pyximport
    pyximport.install(setup_args={"include_dirs": np.get_include()}, reload_support=True)
except Exception:
    pass

from fast_ops_c import *
import math
import nump as np

NAN = float("nan")

import unittest
class FastOpsTest(unittest.TestCase):

    def test_eq_speed(self):
        from timeit import timeit
        a = np.random.random(500000)
        a[1] = 2.
        a[2] = NAN

        a2 = c_iseq_f(a, 2.)
        def f1(): c_iseq_f2(a, 2.)
        def f2(): c_iseq_f1(a, 2.)

        # warm up
        [f1() for x in range(20)]
        [f2() for x in range(20)]

        n=1000
        dur = timeit(f1, number=n)
        print dur, "DUR1 s/iter", dur/n

        dur = timeit(f2, number=n)
        print dur, "DUR2 s/iter", dur/n

        dur = timeit(f1, number=n)

        print dur, "DUR1 s/iter", dur/n
        assert dur/n <= 0.005

        dur = timeit(f2, number=n)
        print dur, "DUR2 s/iter", dur/n

        print a2[:10]
        assert a2[0] == 0.
        assert a2[1] == 1.
        assert math.isnan(a2[2])

我猜测for _x in x被解释为执行x的python迭代器,而for i in range(n):被解释为C for循环,而x[i]被解释为C x[i]数组索引。

然而,我有点猜测并试图以身作则。 在其working with numpy(和here)文档中,Cython对于numpy的优化是什么,而不是什么。是否有 优化的指南。

类似地,下面假定连续的数组内存,比上述任何一个快得多。

@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False)  # turn off negative index wrapping for entire function
def c_iseq_f(np.ndarray[np.double_t, ndim=1, cast=False, mode="c"] x not None, double val):
    cdef np.ndarray[np.double_t, ndim=1] result = np.empty_like(x)
    cdef size_t i = 0

    cdef double* _xp = &x[0]
    cdef double* _resultp = &result[0]
    for i in range(len(x)):
        _x = _xp[i]
        _resultp[i] = (_x-_x) + (_x==val)
    return result

2 个答案:

答案 0 :(得分:3)

令人惊讶的原因是x[i]看起来更加微妙。我们来看看下面的cython函数:

%%cython
def cy_sum(x):
   cdef double res=0.0
   cdef int i
   for i in range(len(x)):
         res+=x[i]
   return res

衡量其表现:

import numpy as np
a=np.random.random((2000,))
%timeit cy_sum(a)

>>>1000 loops, best of 3: 542 µs per loop

这很慢!如果你查看生成的C代码,你会看到x[i]使用__getitem()__功能,它需要C-double,创建一个python-Float对象,将它注册在垃圾中收集器,将其强制转换为C-double并销毁临时python-float。单个double的额外开销 - 添加!

让我们向cython说明x是一个类型化的内存视图:

%%cython
def cy_sum_memview(double[::1] x):
   cdef double res=0.0
   cdef int i
   for i in range(len(x)):
         res+=x[i]
   return res

表现更好:

%timeit cy_sum_memview(a)   
>>> 100000 loops, best of 3: 4.21 µs per loop

那发生了什么?因为cython知道,xtyped memory view(我宁愿在cython函数的签名中使用类型化的内存视图而不是numpy-array),所以它不再必须使用python-functions {{ 1}}但可以直接访问__getitem__值而无需创建中间python-float。

但回到numpy-arrays。 Numpy数组可以通过cython作为类型化内存视图进行解释,因此C-double可以转换为对底层内存的直接/快速访问。

那么范围呢?

x[i]

又慢了。因此,cython似乎不够聪明,无法通过直接/快速访问来替换for-range,并再次使用python-functions来产生开销。

我必须承认我和你一样惊讶,因为乍一看没有充分理由说为什么cython不能在for-range的情况下使用快速访问。但这就是它......

我不确定,这就是原因,但二维数组的情况并非如此简单。请考虑以下代码:

%%cython
cimport array
def cy_sum_memview_for(double[::1] x):
    cdef double res=0.0
    cdef double x_
    for x_ in x:
          res+=x_
    return res

%timeit cy_sum_memview_for(a)
>>> 1000 loops, best of 3: 736 µs per loop

此代码有效,因为import numpy as np a=np.zeros((5,1), dtype=int) for d in a: print(int(d)+1) 是一个1长的数组,因此可以通过d转换为Python标量。

然而,

int(d)

抛出,因为现在for d in a.T: print(int(d)+1) 的长度为d,因此无法将其转换为Python标量。

因为我们希望这个代码在cython化时具有与纯Python相同的行为,并且只能在运行时确定转换为int是否为Ok,我们首先使用Python对象5而且我们只能访问这个数组的内容。

答案 1 :(得分:2)

Cython可以将range(len(x))循环转换为几乎onLy C代码:

for i in range(len(x)):

生成的代码:

  __pyx_t_6 = PyObject_Length(((PyObject *)__pyx_v_x)); if (unlikely(__pyx_t_6 == -1)) __PYX_ERR(0, 17, __pyx_L1_error)
  for (__pyx_t_7 = 0; __pyx_t_7 < __pyx_t_6; __pyx_t_7+=1) {
    __pyx_v_i = __pyx_t_7;

但这仍然是Python:

 for _x in x:        # Iterate over elements

生成的代码:

  if (likely(PyList_CheckExact(((PyObject *)__pyx_v_x))) || PyTuple_CheckExact(((PyObject *)__pyx_v_x))) {
    __pyx_t_1 = ((PyObject *)__pyx_v_x); __Pyx_INCREF(__pyx_t_1); __pyx_t_6 = 0;
    __pyx_t_7 = NULL;
  } else {
    __pyx_t_6 = -1; __pyx_t_1 = PyObject_GetIter(((PyObject *)__pyx_v_x)); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 12, __pyx_L1_error)
    __Pyx_GOTREF(__pyx_t_1);
    __pyx_t_7 = Py_TYPE(__pyx_t_1)->tp_iternext; if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 12, __pyx_L1_error)
  }
  for (;;) {
    if (likely(!__pyx_t_7)) {
      if (likely(PyList_CheckExact(__pyx_t_1))) {
        if (__pyx_t_6 >= PyList_GET_SIZE(__pyx_t_1)) break;
        #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS
        __pyx_t_3 = PyList_GET_ITEM(__pyx_t_1, __pyx_t_6); __Pyx_INCREF(__pyx_t_3); __pyx_t_6++; if (unlikely(0 < 0)) __PYX_ERR(0, 12, __pyx_L1_error)
        #else
        __pyx_t_3 = PySequence_ITEM(__pyx_t_1, __pyx_t_6); __pyx_t_6++; if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 12, __pyx_L1_error)
        __Pyx_GOTREF(__pyx_t_3);
        #endif
      } else {
        if (__pyx_t_6 >= PyTuple_GET_SIZE(__pyx_t_1)) break;
        #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS
        __pyx_t_3 = PyTuple_GET_ITEM(__pyx_t_1, __pyx_t_6); __Pyx_INCREF(__pyx_t_3); __pyx_t_6++; if (unlikely(0 < 0)) __PYX_ERR(0, 12, __pyx_L1_error)
        #else
        __pyx_t_3 = PySequence_ITEM(__pyx_t_1, __pyx_t_6); __pyx_t_6++; if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 12, __pyx_L1_error)
        __Pyx_GOTREF(__pyx_t_3);
        #endif
      }
    } else {
      __pyx_t_3 = __pyx_t_7(__pyx_t_1);
      if (unlikely(!__pyx_t_3)) {
        PyObject* exc_type = PyErr_Occurred();
        if (exc_type) {
          if (likely(exc_type == PyExc_StopIteration || PyErr_GivenExceptionMatches(exc_type, PyExc_StopIteration))) PyErr_Clear();
          else __PYX_ERR(0, 12, __pyx_L1_error)
        }
        break;
      }
      __Pyx_GOTREF(__pyx_t_3);
    }
    __pyx_t_8 = __pyx_PyFloat_AsDouble(__pyx_t_3); if (unlikely((__pyx_t_8 == (double)-1) && PyErr_Occurred())) __PYX_ERR(0, 12, __pyx_L1_error)
    __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
    __pyx_v__x = __pyx_t_8;
/* … */
  }
  __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;

生成此输出通常是查找的最佳方式。