Python:更快地计算两个dicts的余弦相似度

时间:2014-03-13 14:35:53

标签: python scipy

我有两个词:

d1 = {1234: 4, 125: 7, ...}
d2 = {1234: 8, 1288: 5, ...}

dicts的长度从10到40000不等。为了计算余弦相似度,我使用这个函数:

from scipy.linalg import norm
def simple_cosine_sim(a, b):
    if len(b) < len(a):
        a, b = b, a

    res = 0
    for key, a_value in a.iteritems():
        res += a_value * b.get(key, 0)
    if res == 0:
        return 0

    try:
        res = res / norm(a.values()) / norm(b.values())
    except ZeroDivisionError:
        res = 0
    return res 

是否可以更快地计算相似度?

UPD :使用Cython + 15%的速度重写代码。感谢@Davidmh

from scipy.linalg import norm

def fast_cosine_sim(a, b):
    if len(b) < len(a):
        a, b = b, a

    cdef long up, key
    cdef int a_value, b_value

    up = 0
    for key, a_value in a.iteritems():
        b_value = b.get(key, 0)
        up += a_value * b_value
    if up == 0:
        return 0
    return up / norm(a.values()) / norm(b.values())

2 个答案:

答案 0 :(得分:1)

如果索引不是太高,您可以将每个字典转换为数组。如果它们非常大,则可以使用稀疏数组。然后,余弦相似性只会使它们两者相乘。如果您必须重复使用相同的字典进行多次计算,则此方法的效果最佳。

如果这不是一个选项,只要你注释a_value和b_value,Cython应该非常快。

修改 看看你的Cython重写,我看到了一些改进。第一件事是做一个cython -a来生成编辑的HTML报告,看看哪些东西已加速,哪些东西没有。首先,你定义&#34; up&#34;只要,但你总结整数。此外,在您的示例中,键是整数,但您将它们声明为double。另一个简单的方法是输入输入作为dicts。

另外,检查C代码时,似乎有一些无法检查,您可以使用@ cython.nonechecks(False)禁用它。

实际上,字典的实现非常有效,所以在一般情况下,你可能不会比这更好。如果您需要充分利用代码,可能需要使用C API替换一些调用:http://docs.python.org/2/c-api/dict.html

cpython.PyDict_GetItem(a, key)

但是,你将负责引用计数并从PyObject *转换为int,以获得可疑的性能提升。

任何方式,代码的开头都是这样的:

cimport cython

@cython.nonecheck(False)
@cython.cdivision(True)
def fast_cosine_sim(dict a, dict b):
    if len(b) < len(a):
        a, b = b, a

    cdef int up, key
    cdef int a_value, b_value

另一个问题是:你的指挥官是否很大?因为如果不是这样,那么规范的计算实际上可能是一个重要的开销。

<强> EDIT2: 另一种可能的方法是仅查看必要的密钥。说:

from scipy.linalg import norm
cimport cython

@cython.nonecheck(False)
@cython.cdivision(True)
def fast_cosine_sim(dict a, dict b):
    cdef int up, key
    cdef int a_value, b_value

    up = 0
    for key in set(a.keys()).intersection(b.keys()):
        a_value = a[key]
        b_value = b[key]
        up += a_value * b_value
    if up == 0:
        return 0
    return up / norm(a.values()) / norm(b.values())

这在Cython中非常有效。实际性能可能取决于密钥之间的重叠程度。

答案 1 :(得分:1)

从算法的角度来看,没有。你已经处于复杂性O(N)。但是,你可以使用一些计算技巧。

您可以使用多处理模块将a_value * b.get(key, 0)乘法分配给多个工作人员,从而利用您拥有的所有计算机核心。请注意,使用线程不会产生此效果,因为Python具有全局解释器锁。

最简单的方法是使用Pool对象的multiproccess.Poolmap方法。

我强烈建议使用Python内置的cProfiler来检查代码中的热点。这非常容易。跑吧:

python -m cProfile myscript.py