我有两个矩阵A,B,NxKxD维度,我想得到矩阵C,NxKxDxD维度,其中C [n,k] = A [n,k] x B [n,k] .T(这里&# 34; x"表示维度为Dx1和1xD的矩阵的乘积,因此结果必须是DxD维度),所以现在我的代码看起来像这样(这里A = B = X):
def square(X):
out = np.zeros((N, K, D, D))
for n in range(N):
for k in range(K):
out[n, k] = np.dot(X[n, k, :, np.newaxis], X[n, k, np.newaxis, :])
return out
对于大N和K来说可能会很慢,因为python是循环的。有没有办法在一个 numpy 函数中进行这种乘法?
答案 0 :(得分:1)
您似乎没有使用np.dot
进行总和缩减,而只是用于导致广播的扩展。因此,您可以简单地使用np.newaxis
/ None扩展数组以获得一个维度,并让隐式广播帮助。
因此,实现将是 -
X[...,None]*X[...,None,:]
有关广播的更多信息,具体如何添加新轴可以在this other post
中找到。