在PyTorch中使用3D张量索引对4D张量进行切片

时间:2019-10-20 09:21:08

标签: pytorch slice tensor

我有一个 4D张量(恰好是由三个批次的堆栈组成的56x56图像,每个批次具有16张图像),尺寸为 [ 16、3、56、56] 。我的目标是为每个像素选择这三批中的正确批处理(其索引图的大小为 [16、56、56] ),然后获取所需的图像。

现在,我要选择这三批图像中的特定图像批次,其图像的值如

       [[[ 0,  0,  2,  ...,  0,  0,  0],
         [ 0,  0,  2,  ...,  0,  0,  0],
         [ 0,  0,  0,  ...,  0,  0,  0],
         ...,
         [ 0,  0,  0,  ...,  0,  0,  0],
         [ 0,  2,  0,  ...,  0,  0,  0],
         [ 0,  2,  2,  ...,  0,  0,  0]],

        [[ 0,  2,  0,  ...,  1,  1,  0],
         [ 0,  2,  0,  ...,  0,  0,  0],
         [ 0,  0,  0,  ...,  0,  2,  0],
         ...,
         [ 0,  0,  0,  ...,  0,  2,  0],
         [ 0,  0,  2,  ...,  0,  2,  0],
         [ 0,  0,  2,  ...,  0,  0,  0]]]

因此对于0,将从第一批中选择值,其中1和2表示我要从第二和第三批中选择值。

以下是一些索引的可视化效果,每种颜色表示另一批次。

enter image description here

enter image description here

我试图转置4D张量以匹配索引的尺寸,但是没有用。它所做的就是给我一份我尝试选择的尺寸的副本。意思

tposed = torch.transpose(fourD, 0,1) print(indices.size(),
outs.size(), tposed[:, indices].size())

输出

torch.Size([16, 56, 56]) torch.Size([16, 3, 56, 56]) torch.Size([3, 16, 56, 56, 56, 56])

我需要的形状是

torch.Size([16, 56, 56]) or torch.Size([16, 1, 56, 56])

,例如,如果我尝试仅使用

为批次中的第一个图像选择正确的值
fourD[0,indices].size()

我得到一个像

的形状
torch.Size([16, 56, 56, 56, 56])

更不用说当我在整个张量上尝试时出现内存不足错误。

我非常感谢使用这些索引为图像中的每个像素选择这三批中的任意一个

注意:

我尝试了该选项

outs[indices[:,None,:,:]].size()

返回

torch.Size([16, 1, 56, 56, 3, 56, 56])

编辑:torch.take并没有太大帮助,因为它将输入张量视为一维数组。

1 个答案:

答案 0 :(得分:1)

结果证明PyTorch中有一个功能具有我要搜索的功能。

torch.gather(fourD, 1, indices.unsqueeze(1)) 

工作完成了。

Here很好地解释了聚集的作用。