对于n模式张量矩阵乘积,Einsum比显式Numpy实现要慢

时间:2019-04-19 21:04:06

标签: python numpy

我正在尝试使用Numpy在Python中高效地实现n模式张量矩阵乘积(由Kolda和Bader定义:https://www.sandia.gov/~tgkolda/pubs/pubfiles/SAND2007-6702.pdf)。操作有效地降到了(对于矩阵U,张量X和轴/模式k):

  1. 通过折叠所有其他轴,从X提取沿k轴的所有矢量。

  2. 使用标准矩阵乘法将这些向量乘以U。

  3. 使用相同的形状将向量再次插入输出张量,除了X.shape [k]等于U.shape [0](最初,X.shape [k]必须等于矩阵相乘的结果等于U.shape [1]。

一段时间以来,我一直在使用显式实现来单独执行所有这些步骤:

  1. 转置张量以将轴k移到前面(在我的完整代码中,在k == X.ndim-1的情况下,我添加了一个例外,在这种情况下,将其保留在那里并转置所有以后的操作会更快,或者至少在我的应用程序中,但这与此处无关)。

  2. 重塑张量以折叠所有其他轴。

  3. 计算矩阵乘法。

  4. 重塑张量以重建所有其他轴。

  5. 将张量转回原始顺序。

我认为此实现会创建许多不必要的(大)数组,因此一旦发现np.einsum,我认为这会大大加快速度。但是,使用下面的代码,我得到的结果更糟:

import numpy as np
from time import time

def mode_k_product(U, X, mode):
    transposition_order = list(range(X.ndim))
    transposition_order[mode] = 0
    transposition_order[0] = mode
    Y = np.transpose(X, transposition_order)
    transposed_ranks = list(Y.shape)
    Y = np.reshape(Y, (Y.shape[0], -1))
    Y = U @ Y
    transposed_ranks[0] = Y.shape[0]
    Y = np.reshape(Y, transposed_ranks)
    Y = np.transpose(Y, transposition_order)
    return Y

def einsum_product(U, X, mode):
    axes1 = list(range(X.ndim))
    axes1[mode] = X.ndim + 1
    axes2 = list(range(X.ndim))
    axes2[mode] = X.ndim
    return np.einsum(U, [X.ndim, X.ndim + 1], X, axes1, axes2, optimize=True)

def test_correctness():
    A = np.random.rand(3, 4, 5)
    for i in range(3):
        B = np.random.rand(6, A.shape[i])
        X = mode_k_product(B, A, i)
        Y = einsum_product(B, A, i)
        print(np.allclose(X, Y))

def test_time(method, amount):
    U = np.random.rand(256, 512)
    X = np.random.rand(512, 512, 256)
    start = time()
    for i in range(amount):
        method(U, X, 1)
    return (time() - start)/amount

def test_times():
    print("Explicit:", test_time(mode_k_product, 10))
    print("Einsum:", test_time(einsum_product, 10))

test_correctness()
test_times()

我的时间:

明确:3.9450525522232054

Einsum:15.873924326896667

这是正常现象还是我做错了什么?我知道在某些情况下存储中间结果会降低复杂度(例如,链矩阵乘法),但是在这种情况下,我无法想到任何重复的计算。矩阵乘法是否经过优化,从而消除了不转置的好处(从技术上讲它的复杂度较低)?

2 个答案:

答案 0 :(得分:1)

我更熟悉使用einsum的下标样式,因此计算出了这些等效项:

In [194]: np.allclose(np.einsum('ij,jkl->ikl',B0,A), einsum_product(B0,A,0))          
Out[194]: True
In [195]: np.allclose(np.einsum('ij,kjl->kil',B1,A), einsum_product(B1,A,1))          
Out[195]: True
In [196]: np.allclose(np.einsum('ij,klj->kli',B2,A), einsum_product(B2,A,2))          
Out[196]: True

使用mode参数,您在einsum_product中的方法可能是最好的。但是,这些等效项可以帮助我更好地形象化计算结果,并且可能对其他人有所帮助。

时间基本上应该相同。 einsum_product中有一个额外的设置时间,应该在更大的尺寸中消失。

答案 1 :(得分:0)

在更新了Numpy之后,无论是否使用多线程,Einsum都比显式方法稍慢(请参阅我的问题的注释)。