重塑Pytorch张量

时间:2018-05-17 12:38:56

标签: python arrays numpy matrix pytorch

我在Pytorch中有一个大小为(24, 2, 224, 224)的张量。

  • 24 =批量大小
  • 2 =表示前景和的矩阵 背景
  • 224 =图像高度尺寸
  • 224 =图像宽度 尺寸

这是执行二进制分段的CNN的输出。在2个矩阵的每个单元格中存储该像素为前景或背景的概率:[n][0][h][w] + [n][1][h][w] = 1用于每个坐标

我想将它重塑为一个大小为(24, 1, 224, 224)的张量。根据概率较高的矩阵,新图层中的值应为01

我该怎么做?我应该使用哪种功能?

1 个答案:

答案 0 :(得分:1)

使用torch.argmax()(对于PyTorch +0.4):

prediction = torch.argmax(tensor, dim=1) # with 'dim' the considered dimension 
prediction = prediction.unsqueeze(1) # to reshape from (24, 224, 224) to (24, 1, 224, 224)

如果PyTorch版本低于0.4.0,可以使用tensor.max()返回最大值及其索引(但不能在索引值上区分):

_, prediction = tensor.max(dim=1)
prediction = prediction.unsqueeze(1) # to reshape from (24, 224, 224) to (24, 1, 224, 224)