更快地计算Cython中的平方范数

时间:2017-11-18 07:21:03

标签: python python-3.x performance numpy cython

我想计算平方范数,可以写成 enter image description here

W是一个矩阵,有V行。你是一个载体。我有两个numpy对象W和你。

import numpy as np
import numpy.random as npr
V = 10
W = npr.normal(size=(V, 3))
u = npr.normal(size=(1,3))

如果我逐行计算,我可以这样做:

res = np.zeros(V)
for v in range(V):
    res[v] = (W[v] - u).dot((W[v] - u).transpose())

但是一旦V变大(比如5000),它可能会很慢,我需要一次又一次地重新计算它。 所以我尝试了矩阵乘法,但它没有成功,因为它不是逐行乘法。

((W - u).transpose()).dot(W - u)

如何在Numpy中快速计算平方范数?

我打算使用Cython,那么循环中的逐行乘法会比Numpy更快吗?我了解到Cython有parallelization,但是如果我在内部使用Numpy对象,似乎Cython无法并行化for loop(在20月11日添加:可能我不能在prange内使用Python对象但我可以使用Numpy对象。)

1 个答案:

答案 0 :(得分:3)

方法#1

您可以使用基于快速BLAS的np.dot权利和NumPy,而不需要任何循环,就像这样 -

res = (W**2).sum(1) + (u**2).sum(1) -2*W.dot(u.ravel())

引入np.einsum以获取W的{​​{1}}和np.inner的行方式摘要 -

u

方法说明

每次迭代都有res = np.einsum('ij,ij->i',W,W) + np.inner(u,u).ravel() -2*W.dot(u.ravel()) ,我们正在做内点产品。因为,我们对(W[v] - u).dot((W[v] - u)的所有行执行此操作,在利用广播时转换为W

现在,

((W - u)**2).sum(1)

因此,

(Xik-Yjk)**2 = Xik**2 + Yjk**2 - 2*Xik*Yjk

RHS的最后一个术语基本上是矩阵乘法,我们正在利用sum_k((Xik-Yjk)**2) = sum_k(Xik**2) + sum_k(Yjk**2) - 2*sum_k(Xik*Yjk)

方法#2

或者,使用更多np.dot,就像这样 -

np.einsum