优化简单的向量运算(python)

时间:2014-10-17 17:17:29

标签: python optimization numpy

在python中是否有一种快速的方法来执行一个简单的操作,从而产生一个A[i,j] = a[i] - b[j]的矩阵 给定两个数组a和b(长度相同,但这可能不相关)?

更准确地说,我所拥有的是在二维空间中的N个点,其位置存储在两个阵列dx和dy中,以及N个位置在tx和ty中的点。 我需要一个矩阵

A[i,j] = (dx[j]-tx[i])**2+(dy[j]-ty[i])**2

我想到的唯一方法是做

A = np.empty([nData,nData])
for i in range(nData):
        A[i] = (dx-tx[i])**2+(dy-ty[i])**2
return A

问题是这太慢了(nData会很大)。如果速度更快,欢迎使用任何符号更改。

(顺便说一下,x ** 2比x * x还是等价?)

2 个答案:

答案 0 :(得分:2)

尝试

>>> a = arange(1, 10)
>>> b = arange(1, 10)
>>> a.reshape(9, 1) - b.reshape(1, 9)
array([[ 0, -1, -2, -3, -4, -5, -6, -7, -8],
       [ 1,  0, -1, -2, -3, -4, -5, -6, -7],
       [ 2,  1,  0, -1, -2, -3, -4, -5, -6],
       [ 3,  2,  1,  0, -1, -2, -3, -4, -5],
       [ 4,  3,  2,  1,  0, -1, -2, -3, -4],
       [ 5,  4,  3,  2,  1,  0, -1, -2, -3],
       [ 6,  5,  4,  3,  2,  1,  0, -1, -2],
       [ 7,  6,  5,  4,  3,  2,  1,  0, -1],
       [ 8,  7,  6,  5,  4,  3,  2,  1,  0]])

该剪辑中发生的事情称为broadcasting,请在该页面上查找说明。如果你不惜一切代价避免显式循环,Numpy通常是最快的。谷歌搜索" numpy矢量化"应该为您提供详细信息。

翻译成您的示例,完整的公式是

(dx.reshape(len(dx), 1) - tx.reshape(1, len(tx)))**2 + \
(dy.reshape(len(dy), 1) - ty.reshape(1, len(ty)))**2 

答案 1 :(得分:2)

您想要计算点之间的所有成对平方欧几里德距离。最快的是使用scipy.distance.cdist

>>> import numpy as np
>>> from scipy.spatial.distance import cdist
>>> x = np.random.rand(10, 2)
>>> t = np.random.rand(8, 2)

>>> cdist(x, t, 'sqeuclidean')
array([[ 0.61048982,  0.04379578,  0.30763149],
       [ 0.02709455,  0.30235292,  0.25135934],
       [ 0.21249888,  0.14024951,  0.28441688],
       [ 0.39221412,  0.01994213,  0.17699239]])

如果你想在numpy中自己做。这样的事情可以解决问题:

>>> np.sum((x[:, None] - t)**2, axis=-1)
array([[ 0.61048982,  0.04379578,  0.30763149],
       [ 0.02709455,  0.30235292,  0.25135934],
       [ 0.21249888,  0.14024951,  0.28441688],
       [ 0.39221412,  0.01994213,  0.17699239]])

或者,使用单独的数组来表示x和y坐标:

>>> dx, dy = x.T
>>> tx, ty = t.T

>>> (dx[:, None] - tx)**2 + (dy[:, None] - ty)**2
array([[ 0.61048982,  0.04379578,  0.30763149],
       [ 0.02709455,  0.30235292,  0.25135934],
       [ 0.21249888,  0.14024951,  0.28441688],
       [ 0.39221412,  0.01994213,  0.17699239]])