计算一个张量与另一个张量的所有滚动之间的成对矩阵乘积的有效方法

时间:2019-12-14 11:45:40

标签: numpy tensorflow tensor

假设我们有两个张量:

形状为(d,m,n)的张量A

张量B,其形状为(d,n,l)。

如果我们想获得A和B最右边矩阵的成对矩阵乘积,我想我们可以使用np.einsum('dmn,... nl-> d ... ml',A, B)大小为(d,d,m,l)。但是,我想得到不是所有对的成对乘积。

导入参数k,1 <= k <= d,我想获得以下成对矩阵乘积:

来自

A(0,...)@ B(0,...)

A(0,...)@ B(k-1,...) ;

来自

A(1,...)@ B(1,...)

A(1,...)@ B(k,...) ;

... ;

来自

A(d-2,...)@ B(d-2,...),

A(d-2,...)@ B(d-1,...)

到 A(d-2,...)@ B(k-3,...) ;

来自

A(d-1,...)@ B(d-1,...)

A(d-1,...)@ B(k-2,...)

请注意,这里我们使用滚动方式来处理张量B。(例如numpy.roll)。

最后,我们实际上得到了一个形状为(d,k,m,l)的张量。

最有效的方法是什么。

我知道几种方法:

  1. 首先获取np.einsum('dmn,... nl-> d ... ml',A,B),然后使用掩码提取(d,k)对。

  2. 瓷砖B,然后以某种方式使用einsum。

但是我认为存在更好的方法。

1 个答案:

答案 0 :(得分:1)

我怀疑您会比for循环做得更好。例如,下面是使用einsum和stride_tricks而不是double for循环的矢量化版本:

enter image description here

代码:

from simple_benchmark import BenchmarkBuilder, MultiArgument
import numpy as np
from numpy.lib.stride_tricks import as_strided
B = BenchmarkBuilder()

@B.add_function()
def loopy(A,B,k): 
    d,m,n = A.shape                                   
    l = B.shape[-1]                     
    out = np.empty((d,k,m,l),int)                      
    for i in range(d):                         
        for j in range(k):                     
            out[i,j] = A[i]@B[(i+j)%d]                      
    return out                     

@B.add_function()
def vectory(A,B,k):                                            
    d,m,n = A.shape                                            
    l = B.shape[-1]                                            
    BB = np.concatenate([B,B[:k-1]],0)                         
    BB = as_strided(BB,(d,k,n,l),np.repeat(BB.strides,(2,1,1)))
    return np.einsum("ikl,ijln->ijkn",A,BB)                    

@B.add_arguments('d x k x m x n x l')
def argument_provider():
    for exp in range(10):
        d,k,m,n,l = (np.r_[1.6,1.5,1.5,1.5,1.5]**exp*(4,2,2,2,2)).astype(int)
        print(d,k,m,n,l)
        A = np.random.randint(0,10,(d,m,n))                            
        B = np.random.randint(0,10,(d,n,l))
        yield k*d*m*n*l,MultiArgument([A,B,k])

r = B.run()
r.plot()

import pylab
pylab.savefig('diagwa.png')