如何在pytorch中将3D张量转换为2D张量?

时间:2018-11-26 02:32:05

标签: python deep-learning pytorch

我是pytorch的新手。我有3D张量(32,10,64),我想要2D张量(32,64)。 我尝试了view(),并在传递到线性层squeeze()后使用,它将其转换为(32,10)。

1 个答案:

答案 0 :(得分:2)

尝试一下

t = torch.rand(32, 10, 64).permute(0, 2, 1)[:, :, -1]

或者Shai指出,您也可以

t = torch.rand(32, 10, 64)[:, -1, :]


print(t.size()) # torch.Size([32, 64])