torch.matmul
具有广播功能,可能会占用过多内存。我正在寻找有效的实现方式,以防止过度使用内存。
例如,输入张量的大小为
adj.size()==[1,3000,3000]
s.size()==torch.Size([235, 3000, 10])
s.transpose(1, 2).size()==torch.Size([235, 10, 3000])
任务是计算
link_loss = adj - torch.matmul(s, s.transpose(1, 2)) #
link_loss = torch.norm(link_loss, p=2)
原始代码位于割炬扩展软件包torch_geometric
中,位于函数dense_diff_pool的定义中。
torch.matmul(s, s.transpose(1,2))
将消耗过多的内存(我的计算机只有2GB的内存空间),从而引发错误:
回溯(最近通话最近一次):
文件“”,第1行,在 torch.matmul(s,s.transpose(1,2))
RuntimeError:$火炬:内存不足:您试图分配7GB。 购买新的RAM!在.. \ aten \ src \ TH \ THGeneral.cpp:201
软件包作者的原始代码包含大于7GB的torch.matmul(s, s.transpose(1, 2)).size()==[235,3000,3000]
。
我的尝试是我尝试使用for
迭代
batch_size=235
link_loss=torch.sqrt(torch.stack([torch.norm(adj - torch.matmul(s[i], s[i].transpose(0, 1)), p=2)**2 for i in range(batch_size)]).sum(dim=0))
众所周知,此for
循环比使用广播或其他pytorch内置函数要慢。
问题:
有没有比使用[... for ...]更好的实现方式呢?
我是学习pytorch的新手。谢谢。