我在Pytorch中有一个大小为(24, 2, 224, 224)
的张量。
这是执行二进制分段的CNN的输出。在2个矩阵的每个单元格中存储该像素为前景或背景的概率:[n][0][h][w] + [n][1][h][w] = 1
用于每个坐标
我想将它重塑为一个大小为(24, 1, 224, 224)
的张量。根据概率较高的矩阵,新图层中的值应为0
或1
。
我该怎么做?我应该使用哪种功能?
答案 0 :(得分:1)
使用torch.argmax()
(对于PyTorch +0.4):
prediction = torch.argmax(tensor, dim=1) # with 'dim' the considered dimension
prediction = prediction.unsqueeze(1) # to reshape from (24, 224, 224) to (24, 1, 224, 224)
如果PyTorch版本低于0.4.0,可以使用tensor.max()
返回最大值及其索引(但不能在索引值上区分):
_, prediction = tensor.max(dim=1)
prediction = prediction.unsqueeze(1) # to reshape from (24, 224, 224) to (24, 1, 224, 224)