3D数组的乘法和切片

时间:2019-03-30 15:13:15

标签: python numpy matrix-multiplication

我有一个大小为5 x 98 x 3的矩阵。我想找到每个98 x 3块的转置,并将其与自身相乘以找到标准偏差。 因此,我希望我的最终答案是5 x 3 x 3。 使用numpy进行此操作的有效方法是什么。

我目前可以使用以下代码执行此操作:

MU.shape[0] = 5
rows = 98
SIGMA = []
    for i in np.arange(MU.shape[0]):
        SIGMA.append([])
        SIGMA[i] = np.matmul(np.transpose(diff[i]),diff[i])
    SIGMA = np.array(SIGMA)
    SIGMA = SIGMA/rows

这里的差异大小为5 x 98 x 3。

2 个答案:

答案 0 :(得分:1)

使用np.einsum来减少最后一个轴的数量-

SIGMA = np.einsum('ijk,ijl->ikl',diff,diff)
SIGMA = SIGMA/rows

使用optimize中带有True值的np.einsum标志来利用BLAS

我们还可以使用np.matmul来获得sum-reductions-

SIGMA = np.matmul(diff.swapaxes(1,2),diff)

答案 1 :(得分:1)

您可以使用此:

my_result = arr1.swapaxes(1,2) @ arr1

进行测试:

import numpy as np

NINETY_EIGHT = 10
arr1 = np.arange(5*NINETY_EIGHT*3).reshape(5,NINETY_EIGHT,3)

my_result = arr1.swapaxes(1,2) @ arr1
print (my_result.shape)

输出:

(5, 3, 3)