处理不同形状的数组时如何对齐矩阵(使用python)

时间:2013-07-30 12:02:44

标签: python arrays numpy matrix

我有以下数组(向量): l = [[0.483,0.923],[0.781,0.188],[0.446,0.564,0.796]]

我编写了以下代码行来计算向量之间的余弦相似度,并得到以下错误消息:ValueError:矩阵未对齐。

import numpy as np
import numpy.linalg as LA
l=[[0.483, 0.923], [0.781, 0.188], [0.446, 0.564, 0.796]]
cx = lambda a, b : round(np.inner(a, b)/(LA.norm(a)*LA.norm(b)), 2)
for v in l:
   for y in l:
    cosine=cx(v,y)
    print cosine

在将数组调整为相等长度(l = [[0.483,0.923],[0.781,0.188],[0.446,0.564]])时,我的代码工作正常。

现在的问题是如何在不调整数组形状的情况下使代码工作? (即如何对齐矩阵)。感谢您的任何建议。

1 个答案:

答案 0 :(得分:3)

由于您使用的是余弦相似度,因此ab应该有几何解释。 (余弦相似性是找到两个向量之间角度的余弦)。

长度为2的向量可以被认为存在于xy-plane中,长度为3的向量可以被认为存在于xyz-space中。 因此,平面中的向量[0.4, 0.9]可以被视为[0.4, 0.9, 0]中的三维向量xyz-space

如果这是合理的,那么在2D矢量和3D矢量之间获取内积相当于在简单地删除第三个分量后获取内积(因为任何东西乘以0都是0)。

因此,您可以通过这种方式定义cx

def cx(a, b) :
    a, b = (a, b) if len(a) < len(b) else (b, a)
    b = b[:len(a)]
    try:
        assert any(a)
        assert any(b)
    except AssertionError:
        print('either a or b is zero')
        # return 0  or 
        # raise 
    return round(np.inner(a, b)/(LA.norm(a)*LA.norm(b)), 2)

通过填写l中的缺失值可以获得更好的性能,因此可以将其制作成NumPy数组。然后,您可以立即将NumPy操作应用于整个数组并消除双Python for-loops

def cosine_similarity(l):
    inner = np.einsum('ij,kj -> ik', l, l)
    norm = np.sqrt(np.einsum('ij -> i', l*l))
    return inner/(norm*norm[:, np.newaxis])

def to_3d(l):
    return np.array([row+[0]*(3-len(row)) for row in l])

np.set_printoptions(precision=2)
print(cosine_similarity(to_3d(l)))

产量

[[ 1.    0.66  0.66]
 [ 0.66  1.    0.53]
 [ 0.66  0.53  1.  ]]

相比
def cx(a, b) :
    a, b = (a, b) if len(a) < len(b) else (b, a)
    b = b[:len(a)]
    return round(np.inner(a, b)/(LA.norm(a)*LA.norm(b)), 2)

def using_cx():
    for v in l:
       for y in l:
        cosine=cx(v,y)

timeit显示速度增加了11倍:

In [90]: %timeit using_cx()
1000 loops, best of 3: 380 us per loop

In [91]: %timeit cosine_similarity(to_3d(l))
10000 loops, best of 3: 32.6 us per loop

计算仍然是二次的 - 如果你想比较l中每一对可能的行,它总是如此。但它更快,因为NumPy函数是用C语言编写的,它往往比在Python循环中调用Python函数的等效代码更快。