优化张量乘法

时间:2018-07-28 17:27:03

标签: matlab numpy dot-product numpy-einsum

我有一个正在尝试优化的实时图像处理程序,所有这些都归结为矩阵乘法。考虑一下我在初始化阶段正在计算的3个张量:

  1. A = np.arange(35 * 51 * 59).reshape([35, 51, 59])
  2. B = np.arange(37 * 51 * 51 * 59).reshape([37, 51, 51, 59])
  3. C = np.arange(59 * 27).reshape([59, 27])

每帧,我都以第四张量的形式获取新数据:

  • M = np.arange(35 * 37 * 59).reshape([35, 37, 59])

当前,我正在计算D = np.einsum('xyf,xtf,ytpf,fr->tpr', M, A, B, C),其中D是我想要的结果,它是程序的主要瓶颈。为了优化它,我尝试遵循两个方向。

首先,我尝试提出张量T,该张量是A, B, C, D的函数,可以对其进行预先计算,然后将其全部转换为D = np.tensordot(M, T, axes=..)。我没有成功。我花了很多时间,甚至有可能吗?

此外,程序本身是用MATLAB编写的。由于它没有内置的张量乘法功能(einsumtensordot等价),因此我目前正在使用tprod工具箱,并进行以下操作:

temp1 = etprod('dcb', A, 'abc', M, 'adc');
temp2 = etprod('dbc', B, 'abcd', temp1, 'adb');
D = etprod('cdb', C, 'ab', temp2, 'acd');

由于MATLAB(用于2D矩阵)中的默认点积函数要比etprod快得多,因此我将A, B, C, D重塑为2D数组的方式是,我将能够使用默认功能,没有手写的for循环。我也没有成功。

有什么想法吗?谢谢!

1 个答案:

答案 0 :(得分:1)

如果使用不同的 M 值多次执行此操作,我们可以定义

"scripts:" {
  "start": "serve-build"
}

整个操作可以分解成二进制步骤:

D0 = np.einsum('xft,fr->tpr',A, B, C)

最后的运算使用D0和M,可以编码为矩阵向量运算。在 Matlab 中是

D0=np.einsum('xtf,ytpf->xyptf',A,B)
D0=np.einsum('xyptf,fr->xyftpr',D0,C)
D=np.einsum('tprxfy,xfy->tpr',D0,M)

然后可以根据需要重新排序。 我们可以把这个顺序写成 (((A,B),C),M)

不过,使用 ((M,C),A,B) 可能会更好

D=reshape(D0.[],numel(M))*M(:);

这种操作顺序的中间数组只有 4 个索引,而不是一个有 6 个索引。如果每个操作都比单个操作快得多,这可能是一个优势。