我有一个点数组列表,当我遍历它时,我想进行就地矩阵乘法,即我希望结果存储在同一矩阵中。
代码本质上是:
for p in p_list:
# R is a 3x3 matrix
p[:,:] = np.matmul(R,p)
此代码未显示任何错误,但结果不正确,就好像在数组中执行乘法并按计算方式进行替换一样,因此它创建了错误的输出矩阵。删除[:,:]可得出正确的乘法。
1)为什么会发生这种情况? 2)我使用[:,:]的主要原因是确保将结果存储回列表p_list中。是否有正确的方法(不使用中间变量)?
答案 0 :(得分:1)
如果我的理解正确,您可以使用enumerate
浏览p_list
的索引和值,并将matmul
的结果分配给给定的索引。
例如,如果您的数据如下:
>>> R
array([[0.00169934, 0.66346914, 0.07109019],
[0.28354322, 0.45933175, 0.55396787],
[0.22061139, 0.18207232, 0.51669746]])
>>> p_list
array([[1., 2., 3.],
[4., 5., 6.]])
那么您可以做:
for i,p in enumerate(p_list):
# R is a 3x3 matrix
p_list[i] = np.matmul(R,p)
以及您产生的p_list
:
>>> p_list
array([[1.54190819, 2.86411033, 2.13484839],
[3.7506842 , 6.75463886, 4.89299187]])
[EDIT] 基于@domochevski的评论,可以通过列表理解更轻松地实现此方法:
np.array([np.matmul(R,p) for p in p_list])
或者,您可以使用np.apply_along_axis
,并应用自定义函数以将matmul(R,x)
返回到每一行:
def my_matmul(x):
return np.matmul(R,x)
p_list = np.apply_along_axis(my_matmul, 1, p_list)
哪个返回相同:
>>> p_list
array([[1.54190819, 2.86411033, 2.13484839],
[3.7506842 , 6.75463886, 4.89299187]])
答案 1 :(得分:1)
matmul
接受out参数
如果p_list
是ndarray
,形状为N, 3
,则可以在一个matmul
中实现整个乘法:
np.matmul(p_list, R.T, out=p_list)
答案 2 :(得分:0)
以@sacul为例:
In [59]: R.shape
Out[59]: (3, 3)
In [60]: p_list.shape
Out[60]: (2, 3)
In [58]: np.array([np.matmul(R,p) for p in p_list])
Out[58]:
array([[1.54190819, 2.86411033, 2.13484841],
[3.7506842 , 6.75463885, 4.89299192]])
einsum
生成相同(新)的数组,而没有外部循环:
In [61]: np.einsum('ij,kj->ki',R,p_list)
Out[61]:
array([[1.54190819, 2.86411033, 2.13484841],
[3.7506842 , 6.75463885, 4.89299192]])
与ufunc
一样,它接受out
参数:
In [63]: np.einsum('ij,kj->ki',R,p_list, out=p_list)
Out[63]:
array([[1.54190819, 2.86411033, 2.13484841],
[3.7506842 , 6.75463885, 4.89299192]])
In [64]: p_list
Out[64]:
array([[1.54190819, 2.86411033, 2.13484841],
[3.7506842 , 6.75463885, 4.89299192]])
我确定它使用了中间缓冲区,但是应该比逐行迭代要快。使用out
比简单地让它返回一个新数组要慢一些。
通过调整尺寸,matmul
可以在一次调用中执行整个计算(关键是将R
的最后一个暗号与p_list
的第二个到最后一个配对(修改))。
In [84]: (R@p_list[:,:,None])[:,:,0]
Out[84]:
array([[1.54190819, 2.86411033, 2.13484841],
[3.7506842 , 6.75463885, 4.89299192]])