我正在尝试在张量之间进行详尽的级联。因此,例如 我有张量:
a = torch.randn(3, 512)
我想串联
concat(t1,t1),concat(t1,t2),concat(t1,t3),concat(t2,t1),concat(t2,t2)....作为一个幼稚的解决方案,
我使用了for
循环:
ans = []
result = []
split = torch.split(a, [1, 1, 1], dim=0)
for i in range(len(split)):
ans.append(split[i])
for t1 in ans:
for t2 in ans:
result.append(torch.cat((t1,t2), dim=1))
问题在于每个纪元花费很长时间并且代码很慢。 我尝试在PyTorch: How to implement attention for graph attention layer上发布有问题的解决方案,但这会导致内存错误。
t1 = a.repeat(1, a.shape[0]).view(a.shape[0] * a.shape[0], -1)
t2 = a.repeat(a.shape[0], 1)
result.append(torch.cat((t1, t2), dim=1))
我敢肯定有一种更快的方法,但是我无法弄清楚。