如何在不显示渐变的情况下打印张量

时间:2018-08-13 18:33:36

标签: python pytorch

如果我做这样的事情:

tmp = torch.ones(3, 2, 2, requires_grad=True)
out = tmp ** 2
print("\n{}".format(out))

我得到的输出是

tensor([[[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]]], grad_fn=<PowBackward0>)

我只想打印出值,而不是grad_fn部分。

但是,做

print("\n{}".format(out[0]))

导致:

tensor([[1., 1.],
        [1., 1.]], grad_fn=<SelectBackward>)

我知道的唯一方法是out.detach(),或者还有另一种/更好的方法?为了澄清,我很高兴计算了梯度。我只想显示没有附加数据的向量值。

1 个答案:

答案 0 :(得分:3)

使用data应该为您完成这项工作:

tmp = torch.ones(3, 2, 2, requires_grad=True)
out = tmp ** 2
print("\n{}".format(out.data))

输出:

tensor([[[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]]])