Matlab中的收缩张量

时间:2019-04-28 16:39:48

标签: matlab tensor

我正在寻找一种在Matlab中收缩张量的两个索引的方法。

说我有一个张量为[17,10,17,12]的张量我正在寻找一个函数,该函数在具有相同索引的第一和第三个维度上求和,并留下一个尺寸为[10,12]的矩阵(类似到二维轨迹。

我目前正在研究张量网络,主要使用“置换”和“重塑”功能。如果一个人收缩多个张量并且从一开始就不小心,则可能会以一个索引想要以[i,j,i,k]形式的一个张量收缩来结束索引。

当然可以回缩张量,以免发生这种情况,但是我仍然对更可靠的解决方案感兴趣。

编辑:

以下方面的作用

A = rand(17,10,17,12);
A_contracted = zeros(10,12);
for i = [1:10]
    for j = [1:12]
        for k = [1:17]
            A_contracted(i,j) = A_contracted(i,j) + A(k,i,k,j);
        end
    end

end

4 个答案:

答案 0 :(得分:2)

这是一种方法:

A_contracted = permute(sum( ...
   A.*((1:size(A,1)).'==reshape(1:size(A,3), 1, 1, [])), [1 3]), [2 4 1 3]);

以上方法使用了implicit expansion,并且可以在sum中同时沿多个维度进行操作,这是Matlab的最新功能。对于较旧的Matlab版本,

A_contracted = permute(sum(sum( ...
   A.*bsxfun(@eq, (1:size(A,1)).', reshape(1:size(A,3), 1, 1, [])),1),3), [2 4 1 3]);

答案 1 :(得分:2)

[我觉得我开始听起来像是破唱片了...]

您应该始终首先将代码实现为循环,然后尝试使用permutereshape进行优化。但是请注意,permute需要复制数据,因此倾向于增加而不是减少工作量。最新版本的MATLAB不再因循环而变慢,因此复制数据不再总是一种有用的技巧来加快处理速度。

例如,问题中的循环可以简化为:

A_contracted = zeros(size(A,2),size(A,4));
for k = 1:size(A,1)
    A_contracted = A_contracted + squeeze(A(k,:,k,:));
end

(我也将其推广为任意大小)。

Luis' answer相比,我看到矢量化方法赢得了小数组的胜利,例如OP中的数组(17x10x17x12)的分辨率为0.09 ms vs 0.19 ms。但是,由于周围时间很少,因此不值得付出努力。但是,对于更大的阵列(我尝试了17x100x17x120),我看到循环方法获得了1.3毫秒和2.6毫秒的时间。

数据越多,仅使用简单的旧循环的优势就越大。 170x100x170x120时为0.04 s和0.45 s。


测试代码:

A = rand(17,100,17,120);
assert(all(method2(A)==method1(A),'all'))
timeit(@()method1(A))
timeit(@()method2(A))

function A_contracted = method1(A)
A_contracted = permute(sum( ...
   A.*((1:size(A,1)).'==reshape(1:size(A,3), 1, 1, [])), [1 3]), [2 4 1 3]);
end

function A_contracted = method2(A)
A_contracted = zeros(size(A,2),size(A,4));
for k = 1:size(A,1)
    A_contracted = A_contracted + squeeze(A(k,:,k,:));
end
end

答案 2 :(得分:1)

我的教授提出了另一种解决方法(以下用method3表示),涉及重塑和矩阵乘法。

  1. 采用收缩索引大小的单位矩阵
  2. 将其重塑为矢量
  3. 重塑您要相应收缩的张量
  4. 将向量和张量相乘
  5. 重塑收缩张量

Luis's(方法1)和Cris's答案(方法2)相比的示例代码:

A = rand(17,10,17,10);

timeit(@()method1(A))
timeit(@()method2(A))
timeit(@()method3(A))

function A_contracted = method1(A)
A_contracted = permute(sum( ...
   A.*((1:size(A,1)).'==reshape(1:size(A,3), 1, 1, [])), [1 3]), [2 4 1 3]);
end


function A_contracted = method2(A)
A_contracted = zeros(size(A,2),size(A,4));
for k = 1:size(A,1)
    A_contracted = A_contracted + squeeze(A(k,:,k,:));
end
end


function A_contracted = method3(A)
sa_1 = size(A,1);
Unity = eye(size(A, 1));
Unity = reshape(Unity, [1,sa_1*sa_1]);
A1 = permute(A, [1,3,2,4]);
A2 = reshape(A1, [sa_1*sa_1, size(A1, 3)* size(A1,4)]);
UnA = Unity*A2;
A_contracted = reshape(UnA, [size(A1,3), size(A1,4)]);
end

方法3在小尺寸上比方法1和方法2都占优势,并且在较大尺寸上也优于方法1,但是对于较大尺寸的循环则被方法1击败。

method3具有(有点个人的)优点,即在我的物理课程中应用程序更直观,因为收缩实际上不是张量本身,而是相对于度量。 method3可能很容易修改以合并此功能。

答案 3 :(得分:0)

非常简单

squeeze(sum(sum(a,3),1))

sum(a,n)对数组的第n个维度求和,而squeeze除去所有单例维度