张量之间的穷举级联

时间:2019-01-14 15:51:44

标签: python deep-learning pytorch attention-model

我正在尝试在张量之间进行详尽的级联。因此,例如 我有张量:

a = torch.randn(3, 512)

我想串联

concat(t1,t1),concat(t1,t2),concat(t1,t3),concat(t2,t1),concat(t2,t2)....

作为一个幼稚的解决方案, 我使用了for循环:

ans = []
result = []
split = torch.split(a, [1, 1, 1], dim=0)

for i in range(len(split)):
    ans.append(split[i])

for t1 in ans:
    for t2 in ans:
        result.append(torch.cat((t1,t2), dim=1))

问题在于每个纪元花费很长时间并且代码很慢。 我尝试在PyTorch: How to implement attention for graph attention layer上发布有问题的解决方案,但这会导致内存错误。

t1 = a.repeat(1, a.shape[0]).view(a.shape[0] * a.shape[0], -1)
t2 = a.repeat(a.shape[0], 1)
result.append(torch.cat((t1, t2), dim=1))

我敢肯定有一种更快的方法,但是我无法弄清楚。

0 个答案:

没有答案