我是pytorch的新手。我有3D张量(32,10,64),我想要2D张量(32,64)。
我尝试了view()
,并在传递到线性层squeeze()
后使用,它将其转换为(32,10)。
答案 0 :(得分:2)
尝试一下
t = torch.rand(32, 10, 64).permute(0, 2, 1)[:, :, -1]
或者Shai指出,您也可以
t = torch.rand(32, 10, 64)[:, -1, :]
print(t.size()) # torch.Size([32, 64])