我正在尝试使用Numpy在Python中高效地实现n模式张量矩阵乘积(由Kolda和Bader定义:https://www.sandia.gov/~tgkolda/pubs/pubfiles/SAND2007-6702.pdf)。操作有效地降到了(对于矩阵U,张量X和轴/模式k):
通过折叠所有其他轴,从X提取沿k轴的所有矢量。
使用标准矩阵乘法将这些向量乘以U。
使用相同的形状将向量再次插入输出张量,除了X.shape [k]等于U.shape [0](最初,X.shape [k]必须等于矩阵相乘的结果等于U.shape [1]。
一段时间以来,我一直在使用显式实现来单独执行所有这些步骤:
转置张量以将轴k移到前面(在我的完整代码中,在k == X.ndim-1的情况下,我添加了一个例外,在这种情况下,将其保留在那里并转置所有以后的操作会更快,或者至少在我的应用程序中,但这与此处无关)。
重塑张量以折叠所有其他轴。
计算矩阵乘法。
重塑张量以重建所有其他轴。
将张量转回原始顺序。
我认为此实现会创建许多不必要的(大)数组,因此一旦发现np.einsum,我认为这会大大加快速度。但是,使用下面的代码,我得到的结果更糟:
import numpy as np
from time import time
def mode_k_product(U, X, mode):
transposition_order = list(range(X.ndim))
transposition_order[mode] = 0
transposition_order[0] = mode
Y = np.transpose(X, transposition_order)
transposed_ranks = list(Y.shape)
Y = np.reshape(Y, (Y.shape[0], -1))
Y = U @ Y
transposed_ranks[0] = Y.shape[0]
Y = np.reshape(Y, transposed_ranks)
Y = np.transpose(Y, transposition_order)
return Y
def einsum_product(U, X, mode):
axes1 = list(range(X.ndim))
axes1[mode] = X.ndim + 1
axes2 = list(range(X.ndim))
axes2[mode] = X.ndim
return np.einsum(U, [X.ndim, X.ndim + 1], X, axes1, axes2, optimize=True)
def test_correctness():
A = np.random.rand(3, 4, 5)
for i in range(3):
B = np.random.rand(6, A.shape[i])
X = mode_k_product(B, A, i)
Y = einsum_product(B, A, i)
print(np.allclose(X, Y))
def test_time(method, amount):
U = np.random.rand(256, 512)
X = np.random.rand(512, 512, 256)
start = time()
for i in range(amount):
method(U, X, 1)
return (time() - start)/amount
def test_times():
print("Explicit:", test_time(mode_k_product, 10))
print("Einsum:", test_time(einsum_product, 10))
test_correctness()
test_times()
我的时间:
明确:3.9450525522232054
Einsum:15.873924326896667
这是正常现象还是我做错了什么?我知道在某些情况下存储中间结果会降低复杂度(例如,链矩阵乘法),但是在这种情况下,我无法想到任何重复的计算。矩阵乘法是否经过优化,从而消除了不转置的好处(从技术上讲它的复杂度较低)?
答案 0 :(得分:1)
我更熟悉使用einsum
的下标样式,因此计算出了这些等效项:
In [194]: np.allclose(np.einsum('ij,jkl->ikl',B0,A), einsum_product(B0,A,0))
Out[194]: True
In [195]: np.allclose(np.einsum('ij,kjl->kil',B1,A), einsum_product(B1,A,1))
Out[195]: True
In [196]: np.allclose(np.einsum('ij,klj->kli',B2,A), einsum_product(B2,A,2))
Out[196]: True
使用mode
参数,您在einsum_product
中的方法可能是最好的。但是,这些等效项可以帮助我更好地形象化计算结果,并且可能对其他人有所帮助。
时间基本上应该相同。 einsum_product
中有一个额外的设置时间,应该在更大的尺寸中消失。
答案 1 :(得分:0)
在更新了Numpy之后,无论是否使用多线程,Einsum都比显式方法稍慢(请参阅我的问题的注释)。