我正在尝试转换一些代码以与Numba一起运行。 np.einsum
不被支持,因此我正尝试用Numba支持的功能替换它。
我部分了解了np.einsum
的工作原理,例如,我明白了:
x, y, z = 3, 2, 4
A = np.arange(x * y * z).reshape(x, y, z)
B = np.arange(x * y).reshape(x, y)
C = np.einsum('ijk,kj->ki', A.T, B)
等效于:
C = np.sum(A.T * B.T, axis=1).T
例如我使用ijk
和3D规范索引,但是现在我有以下我无法理解的表达式:
C = np.einsum('aij,jka->ajk', A, B)
索引'a'
是什么意思?
使用乘法,求和和转置的等效变换是什么?
答案 0 :(得分:3)
您在坐标轴字符串中使用的字母没什么关系(但请参阅本文的底部),例如,我们可以将z
替换为a
:
>>> A = np.arange(3*4*5).reshape(3,4,5)
>>> B = np.arange(5*2*3).reshape(5,2,3)
>>>
>>> np.einsum('aij,jka->ajk',A,B)
array([[[ 0, 90],
[ 204, 306],
[ 456, 570],
[ 756, 882],
[1104, 1242]],
[[ 110, 440],
[ 798, 1140],
[1534, 1888],
[2318, 2684],
[3150, 3528]],
[[ 380, 950],
[1552, 2134],
[2772, 3366],
[4040, 4646],
[5356, 5974]]])
>>> np.einsum('zij,jkz->zjk',A,B)
array([[[ 0, 90],
[ 204, 306],
[ 456, 570],
[ 756, 882],
[1104, 1242]],
[[ 110, 440],
[ 798, 1140],
[1534, 1888],
[2318, 2684],
[3150, 3528]],
[[ 380, 950],
[1552, 2134],
[2772, 3366],
[4040, 4646],
[5356, 5974]]])
没有einsum
的情况:
>>> A.sum(1)[..., None]*B.transpose(2,0,1)
array([[[ 0, 90],
[ 204, 306],
[ 456, 570],
[ 756, 882],
[1104, 1242]],
[[ 110, 440],
[ 798, 1140],
[1534, 1888],
[2318, 2684],
[3150, 3528]],
[[ 380, 950],
[1552, 2134],
[2772, 3366],
[4040, 4646],
[5356, 5974]]])
索引字母的身份很重要,因为它们假定输出轴是字母顺序的,因此输出轴是隐式的
>>> A = np.ones((2,1))
>>> np.einsum('ab', A)
array([[1.],
[1.]])
>>> np.einsum('zb', A)
array([[1., 1.]])