Pytorch张量到numpy将“()”作为形状

时间:2019-07-17 14:54:54

标签: python numpy pytorch

我有一个火炬张量

span_end = tensor([[[13]]])

我执行以下操作

span_end = span_end.view(1).squeeze().data.numpy()
            print(type(span_end))
            print(span_end.shape)

这给了我以下输出

<class 'numpy.ndarray'>
()

然后稍后当我尝试访问0th的{​​{1}}元素时,我得到span_end,因为形状以某种方式为null。我在这里做什么错了?

1 个答案:

答案 0 :(得分:2)

tensor.squeeze()将删除大小为1的所有尺寸,因此在这种情况下所有尺寸都将导致张量没有尺寸。

删除该语句即可。

import torch
span_end = torch.tensor([[[13]]])
span_end = span_end.view(1).numpy()
print(type(span_end))
print(span_end.shape)
print(span_end[0])

输出:

<class 'numpy.ndarray'>
(1,)
13