Einsum椭圆形带来极致的灵活性

时间:2019-04-06 12:54:58

标签: python numpy ellipsis numpy-einsum

我有一个关于einsum省略号的问题,我想肯定会在StackExchange上某个地方,但是我似乎找不到。

基本上,我有一些代码使用numpy的einsum进行许多矩阵和向量的压缩。输入通常是一些参数,然后将这些参数用于创建向量和矩阵。该代码运行良好,但是现在我想对其进行概括,以便可以在一定范围内扫描输入参数。最好的办法是使它们成为向量,并修改我的einsum表达式,以使它们接受简单携带的任意数量的附加维。这个问题是问这是否可能,以及如何做到。


因此,在我看来,此问题归结为以下原因。假设我有一个einsum表达式,该表达式创建后会执行某种矩阵乘法,例如

c = np.einsum('ij,jk->ik', a, b)

现在我想向a和b都添加任意数量的索引,并简单地将它们作为额外的索引添加到最终矩阵中,例如

c = np.einsum('ijabc,jkde->ikabcde', a, b)

现在,当您仅对a或b中的一个进行此操作时,您可以通过省略号轻松完成此操作

c = np.einsum('ij...,jk->ik...', a, b)

所以我的问题是,您是否可以以某种方式在einsum中使用多个椭圆形,例如

c = np.einsum('ij...,jk...->ik...', a, b)

这当然会引发错误,但是希望,这些示例清楚地表明了我的意思。

einsum是否支持这种“多省略号”表示法?还是有其他方法可以实现而不循环?

我的猜测是没有这样的方法,因为人们将不得不告诉einsum以什么顺序放置剩余的索引,即人们将不得不以某种方式标记椭圆。

1 个答案:

答案 0 :(得分:1)

因为没有要对齐的轴,所以我们可以简单地使用tensordot,使不参与总和减少的轴通过附加的rollaxis进行“扩展”,像这样-

np.rollaxis(np.tensordot(a,b,axes=(1,0)),a.ndim-1,1)

如果您想使用einsum,我们可以将其重塑为3D,使它们的最后一个轴为合并的轴(第三个轴向前合并为一个),然后继续进行{ {1}}并最终重塑为它们的einsum形状,这些形状在输出中分散开来,像这样-

ndim-1

我们还可以生成相应的einsum字符串表示法本身,从而跳过所有数组操作,从而专注于字符串操作本身以获得类似的内容-

shp_a = a.shape
shp_b = b.shape
shp_a[:1] + shp_a[2:]
out_shp = shp_a[:1] + (shp_b[1],) + shp_a[2:] + shp_b[2:]

a3D = a.reshape(shp_a[:2]+(-1,))
b3D = b.reshape(shp_b[:2]+(-1,))
out = np.einsum('ijk,jlm->ilkm',a3D,b3D).reshape(out_shp)

几乎没有示例案例可以展示其用法-

import string

def einsum_spreadout(a,b,a_axes,b_axes,a_spread_axis,b_spread_axis):
    from numpy.core import numerictypes as nt

    if isinstance(a_axes, (int, nt.integer)):
        a_axes = (a_axes,)

    if isinstance(b_axes, (int, nt.integer)):
        b_axes = (b_axes,)

    s = string.ascii_letters

    a_str = s[:a.ndim]
    b_str = s[a.ndim:a.ndim+b.ndim]

    b_str_ar = np.frombuffer(b_str,dtype='S1').copy()
    for (i,j) in zip(a_axes,b_axes):
        b_str_ar[j] = a_str[i]
    b_str = ''.join(b_str_ar)    

    out_str = a_str[:a_spread_axis] + b_str[:b_spread_axis]
    out_str += a_str[a_spread_axis:] + b_str[b_spread_axis:]

    out_str_ar = np.frombuffer(out_str,dtype='S1').copy()
    out_str = ''.join(out_str_ar[~np.isin(out_str_ar,np.take(b_str_ar,b_axes))])
    einsum_str = a_str+','+b_str+'->'+out_str

    return np.einsum(einsum_str,a,b)