寻找pytorch矩阵乘法的有效实现以防止占用大量内存

时间:2019-06-07 03:59:25

标签: matrix pytorch broadcasting

pytorch中的

torch.matmul具有广播功能,可能会占用过多内存。我正在寻找有效的实现方式,以防止过度使用内存。

例如,输入张量的大小为 adj.size()==[1,3000,3000] s.size()==torch.Size([235, 3000, 10]) s.transpose(1, 2).size()==torch.Size([235, 10, 3000]) 任务是计算

link_loss = adj - torch.matmul(s, s.transpose(1, 2)) #
link_loss = torch.norm(link_loss, p=2)

原始代码位于割炬扩展软件包torch_geometric中,位于函数dense_diff_pool的定义中。 torch.matmul(s, s.transpose(1,2))将消耗过多的内存(我的计算机只有2GB的内存空间),从而引发错误:

  

回溯(最近通话最近一次):

     

文件“”,第1行,在       torch.matmul(s,s.transpose(1,2))

     

RuntimeError:$火炬:内存不足:您试图分配7GB。   购买新的RAM!在.. \ aten \ src \ TH \ THGeneral.cpp:201

软件包作者的原始代码包含大于7GB的torch.matmul(s, s.transpose(1, 2)).size()==[235,3000,3000]

我的尝试是我尝试使用for迭代

batch_size=235
link_loss=torch.sqrt(torch.stack([torch.norm(adj - torch.matmul(s[i], s[i].transpose(0, 1)), p=2)**2 for i in range(batch_size)]).sum(dim=0))

众所周知,此for循环比使用广播或其他pytorch内置函数要慢。 问题: 有没有比使用[... for ...]更好的实现方式呢? 我是学习pytorch的新手。谢谢。

0 个答案:

没有答案