我有两个词:
d1 = {1234: 4, 125: 7, ...}
d2 = {1234: 8, 1288: 5, ...}
dicts的长度从10到40000不等。为了计算余弦相似度,我使用这个函数:
from scipy.linalg import norm
def simple_cosine_sim(a, b):
if len(b) < len(a):
a, b = b, a
res = 0
for key, a_value in a.iteritems():
res += a_value * b.get(key, 0)
if res == 0:
return 0
try:
res = res / norm(a.values()) / norm(b.values())
except ZeroDivisionError:
res = 0
return res
是否可以更快地计算相似度?
UPD :使用Cython + 15%的速度重写代码。感谢@Davidmh
from scipy.linalg import norm
def fast_cosine_sim(a, b):
if len(b) < len(a):
a, b = b, a
cdef long up, key
cdef int a_value, b_value
up = 0
for key, a_value in a.iteritems():
b_value = b.get(key, 0)
up += a_value * b_value
if up == 0:
return 0
return up / norm(a.values()) / norm(b.values())
答案 0 :(得分:1)
如果索引不是太高,您可以将每个字典转换为数组。如果它们非常大,则可以使用稀疏数组。然后,余弦相似性只会使它们两者相乘。如果您必须重复使用相同的字典进行多次计算,则此方法的效果最佳。
如果这不是一个选项,只要你注释a_value和b_value,Cython应该非常快。
修改强> 看看你的Cython重写,我看到了一些改进。第一件事是做一个cython -a来生成编辑的HTML报告,看看哪些东西已加速,哪些东西没有。首先,你定义&#34; up&#34;只要,但你总结整数。此外,在您的示例中,键是整数,但您将它们声明为double。另一个简单的方法是输入输入作为dicts。
另外,检查C代码时,似乎有一些无法检查,您可以使用@ cython.nonechecks(False)禁用它。
实际上,字典的实现非常有效,所以在一般情况下,你可能不会比这更好。如果您需要充分利用代码,可能需要使用C API替换一些调用:http://docs.python.org/2/c-api/dict.html
cpython.PyDict_GetItem(a, key)
但是,你将负责引用计数并从PyObject *转换为int,以获得可疑的性能提升。
任何方式,代码的开头都是这样的:
cimport cython
@cython.nonecheck(False)
@cython.cdivision(True)
def fast_cosine_sim(dict a, dict b):
if len(b) < len(a):
a, b = b, a
cdef int up, key
cdef int a_value, b_value
另一个问题是:你的指挥官是否很大?因为如果不是这样,那么规范的计算实际上可能是一个重要的开销。
<强> EDIT2:强> 另一种可能的方法是仅查看必要的密钥。说:
from scipy.linalg import norm
cimport cython
@cython.nonecheck(False)
@cython.cdivision(True)
def fast_cosine_sim(dict a, dict b):
cdef int up, key
cdef int a_value, b_value
up = 0
for key in set(a.keys()).intersection(b.keys()):
a_value = a[key]
b_value = b[key]
up += a_value * b_value
if up == 0:
return 0
return up / norm(a.values()) / norm(b.values())
这在Cython中非常有效。实际性能可能取决于密钥之间的重叠程度。
答案 1 :(得分:1)
从算法的角度来看,没有。你已经处于复杂性O(N)。但是,你可以使用一些计算技巧。
您可以使用多处理模块将a_value * b.get(key, 0)
乘法分配给多个工作人员,从而利用您拥有的所有计算机核心。请注意,使用线程不会产生此效果,因为Python具有全局解释器锁。
最简单的方法是使用Pool对象的multiproccess.Pool
和map
方法。
我强烈建议使用Python内置的cProfiler来检查代码中的热点。这非常容易。跑吧:
python -m cProfile myscript.py