如何将PyTorch的torch.inverse()函数应用于批处理中的每个样本?

时间:2017-10-05 21:36:45

标签: python pytorch

这似乎是一个基本问题,但我无法解决这个问题。

在我的神经网络的前向传递中,我有一个8x3x3形状的输出张量,其中8是我的批量大小。我们可以假设每个3x3张量是非奇异矩阵。我需要找到这些矩阵的逆。 PyTorch inverse()函数仅适用于方形矩阵。由于我现在有8x3x3,如何以可区分的方式将此函数应用于批处理中的每个矩阵?

如果我迭代遍历样本并将反转追加到python列表,然后我将其转换为PyTorch张量,那么在backprop期间它是否会出现问题? (我问,因为将PyTorch张量转换为numpy以执行某些操作,然后回到张量,在这种操作的backprop期间不会计算渐变)

当我尝试做类似的事情时,我也会收到以下错误。

a = torch.arange(0,8).view(-1,2,2)
b = [m.inverse() for m in a]
c = torch.FloatTensor(b)
  

TypeError:' torch.FloatTensor' object不支持索引

2 个答案:

答案 0 :(得分:5)

编辑:

从Pytorch 1.0版开始,torch.inverse现在支持批量张量。参见here。因此,您只需使用内置功能torch.inverse

旧答案

有计划很快实现批量逆运算。有关讨论,请参见例如issue 7500issue 9102。但是,在撰写本文时,当前的稳定版本(0.4.1),没有批量逆运算可用。

话虽如此,最近添加了对torch.gesv的批量支持。可以(ab)将其用于定义您自己的成批的逆运算,如下所示:

def b_inv(b_mat):
    eye = b_mat.new_ones(b_mat.size(-1)).diag().expand_as(b_mat)
    b_inv, _ = torch.gesv(eye, b_mat)
    return b_inv

我发现,在GPU上运行时,这比for循环可提高速度。

答案 1 :(得分:0)

您可以使用torch.functional.unbind()拆分张量,对结果的每个元素应用反转,然后叠加回来:

a = torch.arange(0,8).view(-1,2,2)
b = [t.inverse() for t in torch.functional.unbind(a)]
c = torch.functional.stack(b)