MATLAB Numpy元素明智乘法问题

时间:2019-01-18 20:25:48

标签: python matlab numpy

请参见下面的MATLAB代码和等效的Numpy代码。 问题:如何在Numpy中获得与MATLAB相同的D变量?

MATLAB代码

A = [1 2 3; 4 5 6; 7 8 9]

C = [100 1; 10 0.1; 1, 0.01]

C = reshape(C, 1,3,2)

D = bsxfun(@times, A, C)

D(:,:,1) =

       100    20     3
       400    50     6
       700    80     9

D(:,:,2) =

    1.0000    0.2000    0.0300
    4.0000    0.5000    0.0600
    7.0000    0.8000    0.0900

Numpy代码

A = np.array([[1,2,3],[4,5,6],[7,8,9]])

C = np.array([[[100, 1], [10, 0.1], [1, 0.01]]]) # C.shape is (1, 3, 2)

D = A * C.T

D

    array([[[100.  , 200.  , 300.  ],
            [ 40.  ,  50.  ,  60.  ],
            [  7.  ,   8.  ,   9.  ]],

           [[  1.  ,   2.  ,   3.  ],
            [  0.4 ,   0.5 ,   0.6 ],
            [  0.07,   0.08,   0.09]]])

2 个答案:

答案 0 :(得分:2)

您向C添加了一个转置,而该代码在MATLAB代码中不存在。

如果要保持完全相同的数据布局,请将 trailing 单例维度插入A中。在MATLAB中, trailing 单身是隐式的,而在numpy中, leading 单身是隐式的:

>>> D = A[...,None] * C.squeeze()

>>> D
array([[[1.e+02, 1.e+00],
        [2.e+01, 2.e-01],
        [3.e+00, 3.e-02]],

       [[4.e+02, 4.e+00],
        [5.e+01, 5.e-01],
        [6.e+00, 6.e-02]],

       [[7.e+02, 7.e+00],
        [8.e+01, 8.e-01],
        [9.e+00, 9.e-02]]])

此处A[..., None]的形状为(3, 3, 1)C.squeeze()只是消除了多余的前导单例尺寸并使它的形状为(3,2),这些广播形成了(3, 3, 2)的形状。 MATLAB和numpy对多维数组的解释不同,这解释了为什么上面的repr对应于形状为(3,2)的三个数组,而MATLAB为您显示了形状为(3,3的两个数组。但实际上是同一数组:

>>> D[..., 0]
array([[100.,  20.,   3.],
       [400.,  50.,   6.],
       [700.,  80.,   9.]])

>>> D[..., 1]
array([[1.  , 0.2 , 0.03],
       [4.  , 0.5 , 0.06],
       [7.  , 0.8 , 0.09]])

请注意,如果您将numpy代码中的MATLAB顺序保持不变,则可能要在数组中使用fortran布局,否则,在numpy代码中次佳的位置会有“快速”轴。

答案 1 :(得分:2)

你很近。您可以通过将矩阵的转置相乘,然后使用交换轴转置最终的子矩阵来实现此目的

A = np.array([[1,2,3],[4,5,6],[7,8,9]])
C = np.array([[[100, 1], [10, 0.1], [1, 0.01]]]) # C.shape is (1, 3, 2)

D = (A.T*C.T)
D = D.swapaxes(1,2)

您也可以将这些行合并为

D = (A.T*C.T).swapaxes(1,2)

输出

array([[[1.e+02, 2.e+01, 3.e+00],
    [4.e+02, 5.e+01, 6.e+00],
    [7.e+02, 8.e+01, 9.e+00]],

   [[1.e+00, 2.e-01, 3.e-02],
    [4.e+00, 5.e-01, 6.e-02],
    [7.e+00, 8.e-01, 9.e-02]]])