向量化numpy的函数,它将一个向量的每个元素与另一个向量的每个元素相关联

时间:2013-12-06 09:20:31

标签: python numpy vector

我有一个函数,它接受两个向量的两个元素并计算标量值。如何使用numpy工具向量化这个函数,以便我可以编写

A = my_func(vec_a, vec_b)

其中A是维len(vec_a) x len(vec_b)的矩阵。我怎样才能做到这一点?或者我必须在my_func expilicitely迭代? 作为奖励:矩阵将非常稀疏,即my_func中的许多计算值为零。是否可以在实施中包含这种稀疏性?


根据要求,举例:

假设我有两个向量ab

a = numpy.array([...]) # length n
b = numpy.array([...]) # length m

现在,我想调用my_func(a,b)并让它返回一个稀疏矩阵,其密集表示将是

A = [
        [my_func(a[0], b[0]), my_func(a[0], b[1]), ..., my_func(a[0], b[n])],
        [my_func(a[1], b[0]), my_func(a[1], b[1]), ..., my_func(a[1], b[n])],
        ...
        [my_func(a[m], b[0]), my_func(a[m], b[1]), ..., my_func(a[m], b[n])]
]

当然,很多条目都是零。


根据要求,my_func功能。

# note, that each element of the above vectors is a 
# list itself, with 4 elements. 
def my_func(a, b):
    distance = sp.sqrt(sp.sum((a[1:] - b[1:])**2))
    rate = sp.exp(-2*distance/loclength)

    if a[0] < b[0]:
        rate *= sp.exp((a[0] - b[0])/kT)

    return rate if rate > cutoff else 0

1 个答案:

答案 0 :(得分:1)

您可以使用适当的广播来实现:

def my_func_vec(a, b):
    a = np.array(a, copy=False, ndmin=2)
    b = np.array(b, copy=False, ndmin=2)
    a = a[..., np.newaxis, :]
    b = b[..., np.newaxis, :, :]
    distance = np.sqrt(np.sum((a[..., 1:] - b[..., 1:])**2, axis=-1))
    rate = np.exp(-2*distance / loclength)
    mask = a[..., 0] < b[..., 0]
    rate[mask] *= np.exp((a[..., 0] - b[..., 0])[mask] / kT)
    mask = rate <= cutoff
    rate[mask] = 0
    return rate

要测试它,请设置一些虚拟值:

loclength = 1
kT = 1
cutoff = 0.25
a = np.random.rand(3, 5)
b = np.random.rand(4, 5)

现在:

>>> my_func_vec(a, b)
array([[ 0.34220076,  0.        ,  0.25392478,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.25953994,  0.        ,  0.        ]])

而不是:

>>> out = np.empty((3, 4))
>>> for r, j in enumerate(a):
...     for c, k in enumerate(b):
...         out[r, c] = my_func(j, k)
... 
>>> out
array([[ 0.34220076,  0.        ,  0.25392478,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.25953994,  0.        ,  0.        ]])