numpy:在两个2d数组的一个公共轴上广播乘法

时间:2016-11-17 00:38:27

标签: python arrays numpy optimization

我正在寻找一种方法,以元素方式分别将两个2d形状(a,b)和(b,c)数组相乘。在'b'轴上,两个阵列有共同点。

例如,我想要广播(矢量化)的一个例子是:

import numpy as np    

# some dummy data
A = np.empty((2, 3))
B = np.empty((3, 4))

# naive implementation
C = np.vstack(np.kron(A[:, i], B[i, :]) for i in [0, 1, 2])

# this should give (3, 2, 4)
C.shape

有谁知道该怎么做?还有更好的方法吗?

2 个答案:

答案 0 :(得分:3)

使用不同的测试用例:

In [56]: A=np.arange(6).reshape((2,3))
In [57]: B=np.arange(12).reshape((3,4))
In [58]: np.vstack([np.kron(A[:,i],B[i,:]) for i in range(3)])
Out[58]: 
array([[ 0,  0,  0,  0,  0,  3,  6,  9],
       [ 4,  5,  6,  7, 16, 20, 24, 28],
       [16, 18, 20, 22, 40, 45, 50, 55]])

首次尝试使用'einsum,保留所有3个轴(无求和)

In [60]: np.einsum('ij,jk->ijk',A,B)
Out[60]: 
array([[[ 0,  0,  0,  0],
        [ 4,  5,  6,  7],
        [16, 18, 20, 22]],

       [[ 0,  3,  6,  9],
        [16, 20, 24, 28],
        [40, 45, 50, 55]]])

相同的数字,但形状不同。

我可以在输出上重新排序轴,制作一个2x4x3,可以重新变形为8,3并进行转置。

In [64]: np.einsum('ij,jk->ikj',A,B).reshape(8,3).T
Out[64]: 
array([[ 0,  0,  0,  0,  0,  3,  6,  9],
       [ 4,  5,  6,  7, 16, 20, 24, 28],
       [16, 18, 20, 22, 40, 45, 50, 55]])

因此,通过另一次迭代,我可以摆脱转置

In [68]: np.einsum('ij,jk->jik',A,B).reshape(3,8)
Out[68]: 
array([[ 0,  0,  0,  0,  0,  3,  6,  9],
       [ 4,  5,  6,  7, 16, 20, 24, 28],
       [16, 18, 20, 22, 40, 45, 50, 55]])

我应该马上到达那里。 A是(2,3),B是(3,4),我希望(3,2,4)重新变形为(3,8)。 i = 2,j = 3,k = 4 => JIK。

另一种描述问题的方式,

a_ij * b_jk = c_jik

由于我没有使用sum的{​​{1}}部分,因此定期播放的乘法也可以使用一个或多个转置。

答案 1 :(得分:2)

归功于@hpaulj,了解AB 的定义
使用np.outernp.stack

A = np.arange(6).reshape((2, 3))
B = np.arange(12).reshape((3, 4))

np.stack([np.outer(A[:, i], B[i, :]) for i in range(A.shape[1])])

[[[ 0  0  0  0]
  [ 0  3  6  9]]

 [[ 4  5  6  7]
  [16 20 24 28]]

 [[16 18 20 22]
  [40 45 50 55]]]

并使np.einsum处于正确的形状

np.einsum('ij, jk->jik', A, B)

[[[ 0  0  0  0]
  [ 0  3  6  9]]

 [[ 4  5  6  7]
  [16 20 24 28]]

 [[16 18 20 22]
  [40 45 50 55]]]

广播和transpose

(A[:, None] * B.T).transpose(2, 0, 1)

[[[ 0  0  0  0]
  [ 0  3  6  9]]

 [[ 4  5  6  7]
  [16 20 24 28]]

 [[16 18 20 22]
  [40 45 50 55]]]

形状为(3, 2, 4)

<强> 定时
enter image description here