Kronecker乘法仅在特定轴上

时间:2019-08-14 12:09:13

标签: python numpy

如何针对任意形状的A简化和扩展以下代码?

import numpy as np
A  = np.random.random([10,12,13,5,5])
B  = np.zeros([10,12,13,10,10])
s2 = np.array([[0,1],[-1,0]])

for i in range(10):
    for j in range(12):
        for k in range(13):
            B[i,j,k,:,:] = np.kron(A[i,j,k,:,:],s2)

我知道使用np.einsum是可能的,但在那里我还必须明确给出形状。

2 个答案:

答案 0 :(得分:2)

必须为最后两个轴计算输出形状-

out_shp = A.shape[:-2] + tuple(A.shape[-2:]*np.array(s2.shape))

然后可以使用einsum或明暗扩展名-

B_out = (A[...,:,None,:,None]*s2[:,None]).reshape(out_shp)

B_out = np.einsum('ijklm,no->ijklnmo',A,s2).reshape(out_shp)

einsum可以更广泛地用于带有省略号...的普通暗淡-

np.einsum('...lm,no->...lnmo',A,s2).reshape(out_shp)

扩展为通用昏暗

我们可以泛化为通用的暗角,这些暗角将接受通过一些整形工作-

进行克罗内克乘法的轴
def kron_along_axes(a, b, axis):
    # Extend a to the extent of the broadcasted o/p shape
    ae = a.reshape(np.insert(a.shape,np.array(axis)+1,1))

    # Extend b to the extent of the broadcasted o/p shape
    d = np.ones(a.ndim,dtype=int)
    np.put(d,axis,b.shape)
    be = b.reshape(np.insert(d,np.array(axis),1))

    # Get o/p and reshape back to a's dims
    out = ae*be

    out_shp = np.array(a.shape)
    out_shp[list(axis)] *= b.shape
    return out.reshape(out_shp)

因此,要解决我们的问题,应该是-

B = kron_along_axes(A, s2, axis=(3,4))

使用numpy.kron

如果您希望以更慢的速度来寻求优雅和可以,我们也可以将内置np.kron与某些轴排列一起使用-

def kron_along_axes(a, b, axis):
    new_order = list(np.setdiff1d(range(a.ndim),axis)) + list(axis)
    return np.kron(a.transpose(new_order),b).transpose(new_order)

答案 1 :(得分:0)

flattened_A = A.reshape([-1, A.shape[-2], A.shape[-1]])
flattened_kron_product = np.kron(flattened_A, s2)
dims = list(A.shape[:-2]) + [flattened_kron_product.shape[-2], flattened_kron_product.shape[-1]]
result = flattened_kron_product.reshape(dims)

result中减去B将导致填零。矩阵。