多矩阵乘法

时间:2017-01-10 22:20:10

标签: python performance numpy matrix matrix-multiplication

在numpy中,我有一个N 3x3矩阵的数组。这将是我如何存储它们的一个例子(我抽象出内容):

N = 10
matrices = np.ones((N, 3, 3))

我还有一个3向量数组,这是一个例子:

vectors = np.ones((N, 3))

我似乎无法弄清楚如何通过numpy来增加它们,以便实现这样的目标:

result_vectors = []
for matrix, vector in zip(matrices, vectors):
    result_vectors.append(matrix @ vector)

result_vector的形状(在转换为数组时)为(N, 3)。 但是,由于速度的原因,列表实现是不可能的。

我已尝试过各种换位的np.dot,但最终结果并没有得到正确的形状。

1 个答案:

答案 0 :(得分:1)

使用np.einsum -

np.einsum('ijk,ik->ij',matrices,vectors)

步骤:

1)保持第一轴对齐。

2)求和 - 减少输入数组中的最后一个轴。

3)让剩余的轴(来自matrices的第二个轴)按元素倍增。