我正在尝试使用Numba
在python中实现尽可能快的jaccard距离版本@nb.jit()
def nbjaccard(seq1, seq2):
set1, set2 = set(seq1), set(seq2)
return 1 - len(set1 & set2) / float(len(set1 | set2))
def jaccard(seq1, seq2):
set1, set2 = set(seq1), set(seq2)
return 1 - len(set1 & set2) / float(len(set1 | set2))
%%timeit
nbjaccard("compare this string","compare a different string")
- 12.4 ms
%%timeit
jaccard("compare this string","compare a different string")
- 3.87 ms
为什么numba版需要更长的时间?有什么方法可以获得加速?
答案 0 :(得分:2)
在我看来,允许纯对象模式numba函数(或者如果numba意识到整个函数使用python对象没有警告)是一个设计错误 - 因为这些是通常比纯python函数慢一点。
Numba非常强大(类型调度,你可以编写没有类型声明的python代码 - 与C扩展或Cython相比 - 真的很棒)但只有当它支持操作时才会这样:
这意味着“nopython”模式不支持任何未列在其中的操作。如果numba必须回到"object mode"那么要注意:
对象模式
Numba编译模式,它生成将所有值作为Python对象处理的代码,并使用Python C API对这些对象执行所有操作。在对象模式下编译的代码通常不会比Python解释代码运行得快,除非Numba编译器可以利用循环匹配。
这正是你案件中发生的事情:你完全以对象模式运作:
>>> nbjaccard.inspect_types()
[...]
# --- LINE 3 ---
# seq1 = arg(0, name=seq1) :: pyobject
# seq2 = arg(1, name=seq2) :: pyobject
# $0.1 = global(set: <class 'set'>) :: pyobject
# $0.3 = call $0.1(seq1) :: pyobject
# $0.4 = global(set: <class 'set'>) :: pyobject
# $0.6 = call $0.4(seq2) :: pyobject
# set1 = $0.3 :: pyobject
# set2 = $0.6 :: pyobject
set1, set2 = set(seq1), set(seq2)
# --- LINE 4 ---
# $const0.7 = const(int, 1) :: pyobject
# $0.8 = global(len: <built-in function len>) :: pyobject
# $0.11 = set1 & set2 :: pyobject
# $0.12 = call $0.8($0.11) :: pyobject
# $0.13 = global(float: <class 'float'>) :: pyobject
# $0.14 = global(len: <built-in function len>) :: pyobject
# $0.17 = set1 | set2 :: pyobject
# $0.18 = call $0.14($0.17) :: pyobject
# $0.19 = call $0.13($0.18) :: pyobject
# $0.20 = $0.12 / $0.19 :: pyobject
# $0.21 = $const0.7 - $0.20 :: pyobject
# $0.22 = cast(value=$0.21) :: pyobject
# return $0.22
return 1 - len(set1 & set2) / float(len(set1 | set2))
正如您所看到的,每个操作都在Python对象上运行(如每行末尾的:: pyobject
所示)。那是因为numba
不支持str
和set
。所以绝对没有什么可以在这里更快。除了你知道如何使用numpy数组或同类列表(数值类型)来解决这个问题。
在我的电脑上时间差异要大得多(使用numba 0.32.0)但个别时间要快得多 - 微秒(10**-6
秒)而不是毫秒(10**-3
秒):
%timeit nbjaccard("compare this string","compare a different string")
10000 loops, best of 3: 84.4 µs per loop
%timeit jaccard("compare this string","compare a different string")
100000 loops, best of 3: 15.9 µs per loop
请注意,jit
默认为lazy,因此第一次调用应该在执行时间之前完成 - 因为它包含编译代码的时间。
然而,你可以做一个优化:如果你知道两个集合的交集,你可以计算联合的长度(正如@Paul Hankin在他的现在删除的答案中提到的那样):
len(union) = len(set1) + len(set2) - len(intersection)
这将导致以下(纯python)代码:
def jaccard2(seq1, seq2):
set1, set2 = set(seq1), set(seq2)
num_intersection = len(set1 & set2)
return 1 - num_intersection / float(len(set1) + len(set2) - num_intersection)
%timeit jaccard2("compare this string","compare a different string")
100000 loops, best of 3: 13.7 µs per loop
速度不快 - 但更好。
如果您使用cython:
,还有一些改进空间%load_ext cython
%%cython
def cyjaccard(seq1, seq2):
cdef set set1 = set(seq1)
cdef set set2 = set()
cdef Py_ssize_t length_intersect = 0
for char in seq2:
if char not in set2:
if char in set1:
length_intersect += 1
set2.add(char)
return 1 - (length_intersect / float(len(set1) + len(set2) - length_intersect))
%timeit cyjaccard("compare this string","compare a different string")
100000 loops, best of 3: 7.97 µs per loop
这里的主要优点是只需一次迭代就可以创建set2
并计算交叉点中的元素数量(根本不需要创建交集)!
答案 1 :(得分:1)
当我计算这两个函数时,nbjaccard
需要大约4.7微秒(在预热jit之后),普通python使用Numba 0.32.0需要大约3.2微秒。也就是说,我不希望numba在这种情况下为你提供任何加速,因为目前nopython
模式基本上没有字符串支持。这意味着你要经历python对象层,这通常与没有jit运行没什么不同,除非numba可以做一些智能循环提升(即使用纯内在函数而不是python函数编译子块)。除了在numba情况下检查输入的类型之外,您可能只需支付一些小额开销。
我认为最重要的是,您尝试将numba用于目前尚未涵盖的用例。 Numba真正擅长的是处理numpy数组和数值标量值或可以推送到GPU的问题的操作。