我有一个火炬张量
span_end = tensor([[[13]]])
我执行以下操作
span_end = span_end.view(1).squeeze().data.numpy()
print(type(span_end))
print(span_end.shape)
这给了我以下输出
<class 'numpy.ndarray'>
()
然后稍后当我尝试访问0th
的{{1}}元素时,我得到span_end
,因为形状以某种方式为null。我在这里做什么错了?
答案 0 :(得分:2)
tensor.squeeze()
将删除大小为1的所有尺寸,因此在这种情况下所有尺寸都将导致张量没有尺寸。
删除该语句即可。
import torch
span_end = torch.tensor([[[13]]])
span_end = span_end.view(1).numpy()
print(type(span_end))
print(span_end.shape)
print(span_end[0])
输出:
<class 'numpy.ndarray'>
(1,)
13