我正在寻找一种在Matlab中收缩张量的两个索引的方法。
说我有一个张量为[17,10,17,12]的张量我正在寻找一个函数,该函数在具有相同索引的第一和第三个维度上求和,并留下一个尺寸为[10,12]的矩阵(类似到二维轨迹。
我目前正在研究张量网络,主要使用“置换”和“重塑”功能。如果一个人收缩多个张量并且从一开始就不小心,则可能会以一个索引想要以[i,j,i,k]形式的一个张量收缩来结束索引。
当然可以回缩张量,以免发生这种情况,但是我仍然对更可靠的解决方案感兴趣。
编辑:
以下方面的作用
:A = rand(17,10,17,12);
A_contracted = zeros(10,12);
for i = [1:10]
for j = [1:12]
for k = [1:17]
A_contracted(i,j) = A_contracted(i,j) + A(k,i,k,j);
end
end
end
答案 0 :(得分:2)
这是一种方法:
A_contracted = permute(sum( ...
A.*((1:size(A,1)).'==reshape(1:size(A,3), 1, 1, [])), [1 3]), [2 4 1 3]);
以上方法使用了implicit expansion,并且可以在sum
中同时沿多个维度进行操作,这是Matlab的最新功能。对于较旧的Matlab版本,
A_contracted = permute(sum(sum( ...
A.*bsxfun(@eq, (1:size(A,1)).', reshape(1:size(A,3), 1, 1, [])),1),3), [2 4 1 3]);
答案 1 :(得分:2)
[我觉得我开始听起来像是破唱片了...]
您应该始终首先将代码实现为循环,然后尝试使用permute
和reshape
进行优化。但是请注意,permute
需要复制数据,因此倾向于增加而不是减少工作量。最新版本的MATLAB不再因循环而变慢,因此复制数据不再总是一种有用的技巧来加快处理速度。
例如,问题中的循环可以简化为:
A_contracted = zeros(size(A,2),size(A,4));
for k = 1:size(A,1)
A_contracted = A_contracted + squeeze(A(k,:,k,:));
end
(我也将其推广为任意大小)。
与Luis' answer相比,我看到矢量化方法赢得了小数组的胜利,例如OP中的数组(17x10x17x12)的分辨率为0.09 ms vs 0.19 ms。但是,由于周围时间很少,因此不值得付出努力。但是,对于更大的阵列(我尝试了17x100x17x120),我看到循环方法获得了1.3毫秒和2.6毫秒的时间。
数据越多,仅使用简单的旧循环的优势就越大。 170x100x170x120时为0.04 s和0.45 s。
测试代码:
A = rand(17,100,17,120);
assert(all(method2(A)==method1(A),'all'))
timeit(@()method1(A))
timeit(@()method2(A))
function A_contracted = method1(A)
A_contracted = permute(sum( ...
A.*((1:size(A,1)).'==reshape(1:size(A,3), 1, 1, [])), [1 3]), [2 4 1 3]);
end
function A_contracted = method2(A)
A_contracted = zeros(size(A,2),size(A,4));
for k = 1:size(A,1)
A_contracted = A_contracted + squeeze(A(k,:,k,:));
end
end
答案 2 :(得分:1)
我的教授提出了另一种解决方法(以下用method3表示),涉及重塑和矩阵乘法。
与Luis's(方法1)和Cris's答案(方法2)相比的示例代码:
A = rand(17,10,17,10);
timeit(@()method1(A))
timeit(@()method2(A))
timeit(@()method3(A))
function A_contracted = method1(A)
A_contracted = permute(sum( ...
A.*((1:size(A,1)).'==reshape(1:size(A,3), 1, 1, [])), [1 3]), [2 4 1 3]);
end
function A_contracted = method2(A)
A_contracted = zeros(size(A,2),size(A,4));
for k = 1:size(A,1)
A_contracted = A_contracted + squeeze(A(k,:,k,:));
end
end
function A_contracted = method3(A)
sa_1 = size(A,1);
Unity = eye(size(A, 1));
Unity = reshape(Unity, [1,sa_1*sa_1]);
A1 = permute(A, [1,3,2,4]);
A2 = reshape(A1, [sa_1*sa_1, size(A1, 3)* size(A1,4)]);
UnA = Unity*A2;
A_contracted = reshape(UnA, [size(A1,3), size(A1,4)]);
end
方法3在小尺寸上比方法1和方法2都占优势,并且在较大尺寸上也优于方法1,但是对于较大尺寸的循环则被方法1击败。
method3具有(有点个人的)优点,即在我的物理课程中应用程序更直观,因为收缩实际上不是张量本身,而是相对于度量。 method3可能很容易修改以合并此功能。
答案 3 :(得分:0)
非常简单
squeeze(sum(sum(a,3),1))
sum(a,n)
对数组的第n个维度求和,而squeeze
除去所有单例维度