如何将带有grad的Torch张量列表转换为张量

时间:2020-09-28 14:32:32

标签: computer-vision pytorch heatmap tensor face-alignment

我有一个名为pts的变量,其形状为[batch,ch,h,w]。这是一个热图,我想将其转换为第二坐标。目标是pts_o = heatmap_to_pts(pts),其中pts_o将是[batch,ch,2]。到目前为止,我已经编写了此函数,

def heatmap_to_pts(self, pts):  <- pts [batch, 68, 128, 128]
    
    pt_num = []
    
    for i in range(len(pts)):
        
        pt = pts[i]
        if type(pt) == torch.Tensor:

            d = torch.tensor(128)                                                   * get the   
            m = pt.view(68, -1).argmax(1)                                           * indices
            indices = torch.cat(((m / d).view(-1, 1), (m % d).view(-1, 1)), dim=1)  * from heatmaps
        
            pt_num.append(indices.type(torch.DoubleTensor) )   <- store the indices in a list

    b = torch.Tensor(68, 2)                   * trying to convert
    c = torch.cat(pt_num, out=b) *error*      * a list of tensors with grad
    c = c.reshape(68,2)                       * to a tensor like [batch, 68, 2]

    return c

错误显示“ cat():带有out = ...参数的函数不支持自动区分,但其中一个参数需要grad。”。无法执行操作,因为pt_num中的张量需要grad“。

如何将列表转换为张量?

1 个答案:

答案 0 :(得分:1)

错误提示

cat():带有out = ...参数的函数不支持自动微分,但是其中一个参数需要grad。

这意味着torch.cat()之类的函数out=的输出不能用作autograd引擎(执行自动微分)的输入。

原因是(在您的Python列表pt_num中的张量的requires_grad属性具有不同的值,即某些张量具有requires_grad=True,而有些张量具有{{1 }}。

在您的代码中,以下行在逻辑上很麻烦:

requires_grad=False

c = torch.cat(pt_num, out=b) 的返回值,无论是否使用torch.cat() kwarg,都是沿着所提到的维数的张量的串联。

因此,张量out=已经是c中各个张量的串联版本。使用pt_num冗余。因此,您可以简单地摆脱out=b,一切都会好起来的。

out=b