numba范围迭代的更好模式?

时间:2018-03-02 19:18:26

标签: python numba

看看 numba 展开 range/xrange 的方式，很明显其中发生的事情远不止一个等价的 for 循环。例如：

import unittest
class FastOpsTest(unittest.TestCase):
    """Tests for a numba-jitted array sum.

    NOTE(review): this class relies on module-level ``np`` (numpy) and ``jit``
    (presumably ``numba.jit``) names that are not imported in this snippet;
    confirm they are imported at the top of the real file.
    """

    def test_numba_sum(self):
        """Check the jitted sum on a small float array, then dump its typed IR."""
        arr = np.array([1, 2, 3], dtype=float)
        # assertEquals is a deprecated alias (removed in Python 3.12);
        # assertEqual is the supported spelling.
        self.assertEqual(FastOpsTest.sum(arr), 6)
        # Debug aid only: prints numba's IR annotated with inferred types.
        FastOpsTest.sum.inspect_types()

    @staticmethod
    @jit(nopython=True, nogil=True)
    def sum(arr):
        """Return the sum of the elements of ``arr`` via an explicit index loop.

        ``range`` is used instead of the Python 2-only ``xrange`` (which is a
        NameError on Python 3); in nopython mode numba lowers both to the same
        native counted loop.
        """
        n = len(arr)
        result = 0
        for i in range(n):
            result += arr[i]
        return result

……inspect_types() 会输出下面的中间代码（Numba IR）。

这个 for 循环实际上调用了 Python 的 range / xrange C 函数。更糟糕的是，它看起来还在堆上分配了内存（我猜那些 del 语句就是在做这件事）。这似乎非常不理想，尤其是对于嵌套循环。

除了手工重构之外:

# Hand-written equivalent of `for i in range(N):`.
# Use `i < N` rather than `i != N`: range(N) yields nothing when N <= 0,
# whereas `i != N` would never terminate for a negative N.
i = 0
while i < N:
    ...  # loop body goes here
    i += 1

是否有更好的模式来优化numba中的循环?

# NOTE(review): this section (down to `return result`) is the printed output
# of FastOpsTest.sum.inspect_types(). It shows the original source lines
# interleaved with numba's typed IR, which is rendered as `#` comments.
# It is a diagnostic listing, not hand-maintained code.
# The IR operations (`getiter`, `iternext`, `del $x`, ...) describe numba's
# intermediate representation. They are not necessarily what runs after LLVM
# optimization; compare inspect_llvm() / inspect_asm().
# `del` presumably marks the end of an IR variable's lifetime rather than a
# heap free. Confirm this against the numba documentation.
# File: /home/userx/py/fast_ops.py
# --- LINE 176 --- 
# label 0
#   del $0.1
#   del $0.3
#   del $const0.4

@staticmethod

# --- LINE 177 --- 

@jit(nopython=True, nogil=True)

# --- LINE 178 --- 

def sum(arr):

    # --- LINE 179 --- 
    #   arr = arg(0, name=arr)  :: array(float64, 1d, C)
    #   $0.1 = global(len: <built-in function len>)  :: Function(<built-in function len>)
    #   $0.3 = call $0.1(arr, kws=[], args=[Var(arr, /home/userx/py/fast_ops.py (179))], func=$0.1, vararg=None)  :: (array(float64, 1d, C),) -> int64
    #   N = $0.3  :: int64

    N = len(arr)

    # --- LINE 180 --- 
    #   $const0.4 = const(int, 0)  :: int64
    #   result = $const0.4  :: float64
    #   jump 18
    # label 18

    result = 0

    # --- LINE 181 --- 
    #   jump 21
    # label 21
    #   $21.1 = global(xrange: <type 'xrange'>)  :: Function(<built-in function range>)
    #   $21.3 = call $21.1(N, kws=[], args=[Var(N, /home/userx/py/fast_ops.py (179))], func=$21.1, vararg=None)  :: (int64,) -> range_state_int64
    #   del N
    #   del $21.1
    #   $21.4 = getiter(value=$21.3)  :: range_iter_int64
    #   del $21.3
    #   $phi31.1 = $21.4  :: range_iter_int64
    #   del $21.4
    #   jump 31
    # label 31
    #   $31.2 = iternext(value=$phi31.1)  :: pair<int64, bool>
    #   $31.3 = pair_first(value=$31.2)  :: int64
    #   $31.4 = pair_second(value=$31.2)  :: bool
    #   del $31.2
    #   $phi54.1 = $31.3  :: int64
    #   del $phi54.1
    #   $phi54.2 = $phi31.1  :: range_iter_int64
    #   del $phi54.2
    #   $phi34.1 = $31.3  :: int64
    #   del $31.3
    #   branch $31.4, 34, 54
    # label 34
    #   del $31.4
    #   i = $phi34.1  :: int64
    #   del $phi34.1
    #   del i
    #   del $34.5
    #   del $34.6

    for i in xrange(N):

        # --- LINE 182 --- 
        #   $34.5 = getitem(index=i, value=arr)  :: float64
        #   $34.6 = inplace_binop(static_rhs=<object object at 0x7fb921f7cbc0>, rhs=$34.5, immutable_fn=+, lhs=result, static_lhs=<object object at 0x7fb921f7cbc0>, fn=+=)  :: float64
        #   result = $34.6  :: float64
        #   jump 31
        # label 54
        #   del arr
        #   del $phi34.1
        #   del $phi31.1
        #   del $31.4
        #   jump 55
        # label 55
        #   del result

        result += arr[i]

    # --- LINE 183 --- 
    #   $55.2 = cast(value=result)  :: float64
    #   return $55.2

    return result

1 个答案:

答案 0（得分：1）

inspect_types() 返回的是 Numba IR。我对它不太熟悉，但我认为没有理由指望它与实际执行的代码一一对应。

沿着抽象层次继续往下看，你还可以用 inspect_llvm() 方法查看 LLVM IR，用 inspect_asm() 查看实际执行的机器码。在这个例子里，从 LLVM IR 可以清楚地看到，这段代码被编译成了一个非常简单的 for 循环。我认为标签 B24: 就是内层循环。

# Dump the LLVM IR of the first compiled specialization of the jitted sum.
# inspect_llvm() returns a mapping from compiled signature to IR text; with a
# single float64-array call there is exactly one entry to take.
llvm_ir_by_signature = FastOpsTest.sum.inspect_llvm()
first_llvm_ir = next(iter(llvm_ir_by_signature.values()))
print(first_llvm_ir)

# some parts omitted
; Optimized LLVM IR for FastOpsTest.sum, specialized on a 1-D contiguous
; float64 array. The array argument arrives as its unpacked fields
; (%arg.arr.0 .. %arg.arr.6.0). From the uses below, %arg.arr.4 is the data
; pointer, and %arg.arr.5.0 is presumably arr.shape[0], i.e. N.
; The sum is written through %retptr. The i32 return value is presumably a
; status code (0 = ok); confirm this against numba's calling convention.
define i32 @"_ZN8__main__11FastOpsTest7sum$242E5ArrayIdLi1E1C7mutable7alignedE"(double* noalias nocapture %retptr, { i8*, i32 }** noalias nocapture readnone %excinfo, i8* noalias nocapture readnone %env, i8* nocapture readnone %arg.arr.0, i8* nocapture readnone %arg.arr.1, i64 %arg.arr.2, i64 %arg.arr.3, double* nocapture readonly %arg.arr.4, i64 %arg.arr.5.0, i64 %arg.arr.6.0) local_unnamed_addr #0 {
entry:
  ; Skip the loop entirely when N <= 0.
  %.98 = icmp sgt i64 %arg.arr.5.0, 0
  br i1 %.98, label %B24.preheader, label %B40

B24.preheader:                                    ; preds = %entry
  ; Seed the down-counter with N + 1; it is decremented once per iteration.
  %0 = add i64 %arg.arr.5.0, 1
  br label %B24

B24:                                              ; preds = %B24.preheader, %B24
  ; Inner loop of `for i in xrange(N): result += arr[i]`.
  ; No range/xrange call and no allocation remains here. The index has been
  ; rewritten into a pointer walking the data (%lsr.iv8) plus a down-counter
  ; (%lsr.iv); the `lsr.` names come from LLVM's loop strength reduction.
  %lsr.iv8 = phi double* [ %arg.arr.4, %B24.preheader ], [ %scevgep, %B24 ]
  %lsr.iv = phi i64 [ %0, %B24.preheader ], [ %lsr.iv.next, %B24 ]
  %result.07 = phi double [ %.250, %B24 ], [ 0.000000e+00, %B24.preheader ]
  %.242 = load double, double* %lsr.iv8, align 8
  %.250 = fadd double %result.07, %.242
  %lsr.iv.next = add i64 %lsr.iv, -1
  %scevgep = getelementptr double, double* %lsr.iv8, i64 1
  ; Continue while the counter stays above 1, which gives exactly N iterations.
  %.143 = icmp sgt i64 %lsr.iv.next, 1
  br i1 %.143, label %B24, label %B40

B40:                                              ; preds = %B24, %entry
  ; 0.0 if the loop was skipped, otherwise the final running sum.
  %result.0.lcssa = phi double [ 0.000000e+00, %entry ], [ %.250, %B24 ]
  store double %result.0.lcssa, double* %retptr, align 8
  ret i32 0
}