Numpy - 矩阵乘法返回ndarray,而不是sum

时间:2017-12-26 06:54:51

标签: python numpy multidimensional-array matrix-multiplication

所有,我有一个应用程序,当需要两个矩阵相乘时,需要返回一个numpy ndarray,而不是一个简单的和; e.g:

import numpy as np
x = np.array([[1, 1, 0], [0, 1, 1]])
y = np.array([[1, 0, 0, 1], [1, 0, 1, 0], [0, 0, 0, 0]])
w = x @ y
>>> array([[2, 0, 1, 1],
           [1, 0, 1, 0]])

但是,要求是返回一个ndarray(在这种情况下是......):

array([[[1,1,0], [0,0,0], [0,1,0], [1,0,0]],
       [[0,1,0], [0,0,0], [0,1,0], [0,0,0]]])

注意,可以重复矩阵乘法运算;输出将用作下一个矩阵乘法运算的ndarray的左侧矩阵,这将在第二次矩阵乘法运算之后产生更高阶的ndarray等。

有任何方法可以实现这一目标吗?我已经通过子类化np.ndarray as discussed here来查看重载__add____radd__,但大多数都出现了维度不兼容错误。

想法?

更新

解决@Divakar的回答例如,对于链式操作,添加

z = np.array([[1, 1, 0], [0, 0, 0], [1, 0, 0], [0, 1, 0]])
s1 = x[...,None] * y
s2 = s1[...,None] * z

导致不希望的输出。

我怀疑问题从s1开始,在上面的情况下返回s1.shape =(2,3,4)。它应该是(2,4,3),因为[2x3] [3x4] = [2x4],但我们并没有真正在这里求和,只返回一个长度为3的数组。

类似地,s2.shape应该是(2,3,4,3),[顺便说一句]它是,但是有不希望的输出(它不是'错误',不是什么我们正在寻找)。 详细说明,s1 * z应为[2x4] [4x3] = [2x3]矩阵。矩阵的每个元素本身都是[4x3]的ndarray,因为我们在z中有4行来乘以s1中的元素,而s1中的每个元素本身就是3个元素长(同样,我们不是在算术上添加元素) ,但返回ndarrays,扩展维度是操作的R矩阵中的行计数。

最终,期望的输出将是:

s2 = array([[[[1, 1, 0],
              [0, 0, 0],
              [0, 1, 0],
              [0, 0, 0]],

              [[1, 1, 0],
               [0, 0, 0],
               [0, 0, 0],
               [1, 0, 0]],

              [[0, 0, 0],
               [0, 0, 0],
               [0, 0, 0],
               [0, 0, 0]]],


             [[[0, 1, 0],
               [0, 0, 0],
               [0, 1, 0],
               [0, 0, 0]],

              [[0, 1, 0],
               [0, 0, 0],
               [0, 0, 0],
               [0, 0, 0]],

              [[0, 0, 0],
               [0, 0, 0],
               [0, 0, 0],
               [0, 0, 0]]]])

1 个答案:

答案 0 :(得分:4)

将它们扩展为3D并利用broadcasting -

x[:,None] * y.T

np.einsum -

np.einsum('ij,jk->ikj',x,y)

OP's comment和问题引用:

  

...可以重复矩阵乘法运算;输出会   用作下一个矩阵的ndarrays的左侧矩阵   乘法运算,这将产生更高阶的ndarray   在第二次矩阵乘法运算之后等。

看来,我们需要沿着这些方向做点什么 -

s1 = x[...,None] * y
s2 = s1[...,None] * z # and so on.

但是,在这种情况下,轴的顺序会有所不同,但似乎是将解决方案扩展到通用数量的传入2D数组的最简单方法。

在问题中的编辑之后,看起来您正在从第一个轴开始放置传入的数组以进行逐元素乘法。所以,如果我做对了,你可以交换轴来获得正确的顺序,就像这样 -

s1c = (x[...,None] * y).swapaxes(1,-1)
s2c = (s1c.swapaxes(1,-1)[...,None] * z).swapaxes(1,-1) # and so on.

如果您只对最终输出感兴趣,请仅在最后阶段交换轴,并跳过中间轴。