我创造了这个玩具问题,反映了我更大的问题:
import numpy as np
ind = np.ones((3,2,4)) # shape=(3L, 2L, 4L)
dist = np.array([[0.1,0.3],[1,2],[0,1]]) # shape=(3L, 2L)
ans = np.array([np.dot(dist[i],ind[i]) for i in xrange(dist.shape[0])]) # shape=(3L, 4L)
print ans
""" prints:
[[ 0.4 0.4 0.4 0.4]
[ 3. 3. 3. 3. ]
[ 1. 1. 1. 1. ]]
"""
我想尽快做到这一点,所以使用numpy的函数计算ans
应该是最好的方法,因为这个操作很重,我的矩阵很大。
我看到this post,但形状不同,我无法理解我应该使用哪个axes
来解决这个问题。但是,我确信tensordot应该有答案。有什么建议吗?
编辑:我接受@ajcr's answer,但也请阅读我自己的答案,这可能有助于其他人......
答案 0 :(得分:15)
您可以使用np.einsum
进行操作,因为它可以非常小心地控制哪些轴相乘,哪些轴相加:
>>> np.einsum('ijk,ij->ik', ind, dist)
array([[ 0.4, 0.4, 0.4, 0.4],
[ 3. , 3. , 3. , 3. ],
[ 1. , 1. , 1. , 1. ]])
该函数将ind
的第一个轴中的条目与dist
(下标'i'
)的第一个轴中的条目相乘。同上,每个数组的第二个轴(下标'j'
)。我们告诉einsum不要返回3D数组,而是通过从输出下标中省略它来沿轴'j'
求和,从而返回一个2D数组。
np.tensordot
更难以应用于此问题。它自动汇总轴的产品。但是,我们需要两个产品组,但只能将一个加起来。
写np.tensordot(ind, dist, axes=[1, 1])
(如您链接的答案中)为您计算正确的值,但返回形状为(3, 4, 3)
的3D数组。如果您能负担较大阵列的内存成本,可以使用:
np.tensordot(ind, dist, axes=[1, 1])[0].T
这会为您提供正确的结果,但由于tensordot
首先会创建一个大于必要数组,einsum
似乎是更好的选择。
答案 1 :(得分:14)
关注@ajcr's great answer,我想确定哪种方法最快,所以我使用了import timeit
setup_code = """
import numpy as np
i,j,k = (300,200,400)
ind = np.ones((i,j,k)) #shape=(3L, 2L, 4L)
dist = np.random.rand(i,j) #shape=(3L, 2L)
"""
basic ="np.array([np.dot(dist[l],ind[l]) for l in xrange(dist.shape[0])])"
einsum = "np.einsum('ijk,ij->ik', ind, dist)"
tensor= "np.tensordot(ind, dist, axes=[1, 1])[0].T"
print "tensor - total time:", min(timeit.repeat(stmt=tensor,setup=setup_code,number=10,repeat=3))
print "basic - total time:", min(timeit.repeat(stmt=basic,setup=setup_code,number=10,repeat=3))
print "einsum - total time:", min(timeit.repeat(stmt=einsum,setup=setup_code,number=10,repeat=3))
:
tensor - total time: 6.59519493952
basic - total time: 0.159871203461
einsum - total time: 0.263569731028
令人惊讶的结果是:
memory error
所以很明显使用tensordot是错误的做法(更不用说i,j,k = (3000,200,400)
在更大的例子中,就像@ajcr所说的那样。)
由于此示例很小,我将矩阵大小更改为print "einsum - total time:", min(timeit.repeat(stmt=einsum,setup=setup_code,number=50,repeat=3))
print "basic - total time:", min(timeit.repeat(stmt=basic,setup=setup_code,number=50,repeat=3))
,翻转顺序只是为了确保它没有效果并设置另一个具有更高重复次数的测试:
einsum - total time: 13.3184077671
basic - total time: 8.44810031351
结果与第一次运行一致:
i,j,k = (30000,20,40)
然而,测试另一种尺寸增长 - einsum - total time: 0.325594117768
basic - total time: 0.926416766397
导致了以下结果:
i
请参阅注释以获取有关这些结果的说明。
道德是,在寻找特定问题的最快解决方案时,尝试根据类型和形状生成尽可能与原始数据类似的数据。在我的情况下,j,k
比{{1}}小得多,所以我留下了丑陋的版本,在这种情况下也是最快的版本。