一种使用pytorch热编码分割的图像

时间:2019-08-11 09:25:14

标签: python encoding pytorch

我有一个分割图像,其大小为[1,1,256,256]。该图像是二进制分割的图像。我想对其进行热编码以获得大小为[1,2,256,256]的图像。 我尝试了torch.nn.functional.one_hot(img, 2)。但这给了我一张[1,256,256,2]大小的图片。如何获得所需的张量?

1 个答案:

答案 0 :(得分:1)

尝试使用transpose()

img_one_hot = torch.nn.functional.one_hot(img, 2).transpose(1, 4).squeeze(-1)

transpose(1, 4)-交换第1维和第4维,返回[1, 2, 256, 256, 1]形状的张量,squeeze(-1)除去最后的暗淡,得到[1 , 2, 256, 256]形状的张量。