使用numba优化Jaccard距离性能

时间:2017-04-24 19:47:50

标签: python performance numba

我正在尝试使用Numba

在python中实现尽可能快的jaccard距离版本
@nb.jit()
def nbjaccard(seq1, seq2):
    set1, set2 = set(seq1), set(seq2)
    return 1 - len(set1 & set2) / float(len(set1 | set2))

def jaccard(seq1, seq2):
    set1, set2 = set(seq1), set(seq2)
    return 1 - len(set1 & set2) / float(len(set1 | set2))


%%timeit
nbjaccard("compare this string","compare a different string")

- 12.4 ms

%%timeit 
jaccard("compare this string","compare a different string")

- 3.87 ms

为什么numba版需要更长的时间?有什么方法可以获得加速?

2 个答案:

答案 0 :(得分:2)

在我看来,允许对象模式numba函数(或者如果numba意识到整个函数使用python对象没有警告)是一个设计错误 - 因为这些是通常比纯python函数慢一点。

Numba非常强大(类型调度,你可以编写没有类型声明的python代码 - 与C扩展或Cython相比 - 真的很棒)但只有当它支持操作时才会这样:

这意味着“nopython”模式不支持任何未列在其中的操作。如果numba必须回到"object mode"那么要注意:

  

对象模式

     

Numba编译模式,它生成将所有值作为Python对象处理的代码,并使用Python C API对这些对象执行所有操作。在对象模式下编译的代码通常不会比Python解释代码运行得快,除非Numba编译器可以利用循环匹配。

这正是你案件中发生的事情:你完全以对象模式运作:

>>> nbjaccard.inspect_types()

[...]
# --- LINE 3 --- 
#   seq1 = arg(0, name=seq1)  :: pyobject
#   seq2 = arg(1, name=seq2)  :: pyobject
#   $0.1 = global(set: <class 'set'>)  :: pyobject
#   $0.3 = call $0.1(seq1)  :: pyobject
#   $0.4 = global(set: <class 'set'>)  :: pyobject
#   $0.6 = call $0.4(seq2)  :: pyobject
#   set1 = $0.3  :: pyobject
#   set2 = $0.6  :: pyobject

set1, set2 = set(seq1), set(seq2)

# --- LINE 4 --- 
#   $const0.7 = const(int, 1)  :: pyobject
#   $0.8 = global(len: <built-in function len>)  :: pyobject
#   $0.11 = set1 & set2  :: pyobject
#   $0.12 = call $0.8($0.11)  :: pyobject
#   $0.13 = global(float: <class 'float'>)  :: pyobject
#   $0.14 = global(len: <built-in function len>)  :: pyobject
#   $0.17 = set1 | set2  :: pyobject
#   $0.18 = call $0.14($0.17)  :: pyobject
#   $0.19 = call $0.13($0.18)  :: pyobject
#   $0.20 = $0.12 / $0.19  :: pyobject
#   $0.21 = $const0.7 - $0.20  :: pyobject
#   $0.22 = cast(value=$0.21)  :: pyobject
#   return $0.22

return 1 - len(set1 & set2) / float(len(set1 | set2))

正如您所看到的,每个操作都在Python对象上运行(如每行末尾的:: pyobject所示)。那是因为numba不支持strset。所以绝对没有什么可以在这里更快。除了你知道如何使用numpy数组或同类列表(数值类型)来解决这个问题。

在我的电脑上时间差异要大得多(使用numba 0.32.0)但个别时间要快得多 - 秒(10**-6秒)而不是秒(10**-3秒):

%timeit nbjaccard("compare this string","compare a different string")
10000 loops, best of 3: 84.4 µs per loop

%timeit jaccard("compare this string","compare a different string")
100000 loops, best of 3: 15.9 µs per loop

请注意,jit默认为lazy,因此第一次调用应该在执行时间之前完成 - 因为它包含编译代码的时间。

然而,你可以做一个优化:如果你知道两个集合的交集,你可以计算联合的长度(正如@Paul Hankin在他的现在删除的答案中提到的那样):

len(union) = len(set1) + len(set2) - len(intersection)

这将导致以下(纯python)代码:

def jaccard2(seq1, seq2):
    set1, set2 = set(seq1), set(seq2)
    num_intersection = len(set1 & set2)
    return 1 - num_intersection / float(len(set1) + len(set2) - num_intersection)

%timeit jaccard2("compare this string","compare a different string")
100000 loops, best of 3: 13.7 µs per loop

速度不快 - 但更好。

如果您使用

,还有一些改进空间
%load_ext cython

%%cython
def cyjaccard(seq1, seq2):
    cdef set set1 = set(seq1)
    cdef set set2 = set()

    cdef Py_ssize_t length_intersect = 0

    for char in seq2:
        if char not in set2:
            if char in set1:
                length_intersect += 1
            set2.add(char)

    return 1 - (length_intersect / float(len(set1) + len(set2) - length_intersect))

%timeit cyjaccard("compare this string","compare a different string")
100000 loops, best of 3: 7.97 µs per loop

这里的主要优点是只需一次迭代就可以创建set2并计算交叉点中的元素数量(根本不需要创建交集)!

答案 1 :(得分:1)

当我计算这两个函数时,nbjaccard需要大约4.7微秒(在预热jit之后),普通python使用Numba 0.32.0需要大约3.2微秒。也就是说,我不希望numba在这种情况下为你提供任何加速,因为目前nopython模式基本上没有字符串支持。这意味着你要经历python对象层,这通常与没有jit运行没什么不同,除非numba可以做一些智能循环提升(即使用纯内在函数而不是python函数编译子块)。除了在numba情况下检查输入的类型之外,您可能只需支付一些小额开销。

我认为最重要的是,您尝试将numba用于目前尚未涵盖的用例。 Numba真正擅长的是处理numpy数组和数值标量值或可以推送到GPU的问题的操作。