我有一个分割图像,其大小为[1,1,256,256]
。该图像是二进制分割的图像。我想对其进行热编码以获得大小为[1,2,256,256]
的图像。
我尝试了torch.nn.functional.one_hot(img, 2)
。但这给了我一张[1,256,256,2]
大小的图片。如何获得所需的张量?
答案 0 :(得分:1)
尝试使用transpose()
:
img_one_hot = torch.nn.functional.one_hot(img, 2).transpose(1, 4).squeeze(-1)
transpose(1, 4)
-交换第1维和第4维,返回[1, 2, 256, 256, 1]
形状的张量,squeeze(-1)
除去最后的暗淡,得到[1 , 2, 256, 256]
形状的张量。