在 PyTorch 中从批处理张量索引元素

时间:2021-07-02 13:29:50

标签: python indexing pytorch tensor

比如说,我在 PyTorch 中有一批图像。对于每个图像,我还有一个像素位置,比如 (x, y)。可以使用 img[x, y] 读取一张图像的像素值。我正在尝试读取批处理中每个图像的像素值。请参阅下面的代码片段:

import torch

# create tensors to represent random images in torch format
img_1 = torch.rand(1, 200, 300)
img_2 = torch.rand(1, 200, 300)
img_3 = torch.rand(1, 200, 300)
img_4 = torch.rand(1, 200, 300)

# for each image, x-y value are know, so creating a tuple
img1_xy = (0, 10, 70)
img2_xy = (0, 40, 20)
img3_xy = (0, 30, 50)
img4_xy = (0, 80, 60)

# this is what I am doing right now
imgs = [img_1, img_2, img_3, img_4]
imgs_xy = [img1_xy, img2_xy, img3_xy, img4_xy]
x = [img[xy] for img, xy in zip(imgs, imgs_xy)]
x = torch.as_tensor(x)

我的顾虑和问题

  1. 在每个图像中,像素位置,即 (x, y) 是已知的。但是,我必须创建一个包含另外一个元素的元组,即 0 以确保该元组与图像的形状相匹配。有什么优雅的方式吗?
  2. 除了使用 tuple,我们不能使用张量然后获取像素值吗?
  3. 所有图像都可以连接成一个批次作为 img_batch = torch.cat((img_1, img_2, img_3, img_4))。但是元组呢?

1 个答案:

答案 0 :(得分:1)

您可以连接图像以形成 (4, 200, 300) 形堆叠张量。然后,我们可以使用每个图像的已知 (x, y) 对索引到它,如下所示:我们需要 [0, x1, y1] 用于第一幅图像,[1, x2, y2] 用于第二幅图像,[2, x3, y3] 用于第三幅图像等等。这些可以通过“花式索引”来实现:

# stacking as you did
>>> stacked_imgs = torch.cat(imgs)
>>> stacked_imgs.shape
(4, 200, 300)

# no need for 0s in front
>>> imgs_xy = [(10, 70), (40, 20), (30, 50), (80, 60)]

# need xs together and ys together: take transpose of `imgs_xy`
>>> inds_x, inds_y = torch.tensor(imgs_xy).T

>>> inds_x
tensor([10, 40, 30, 80])

>>> inds_y
tensor([70, 20, 50, 60])

# now we index into the batch
>>> num_imgs = len(imgs)
>>> result = stacked_imgs[range(num_imgs), inds_x, inds_y]
>>> result
tensor([0.5359, 0.4863, 0.6942, 0.6071])

我们可以检查结果:

>>> torch.tensor([img[0, x, y] for img, (x, y) in zip(imgs, imgs_xy)])

tensor([0.5359, 0.4863, 0.6942, 0.6071])

回答您的问题:

1:由于我们堆叠了图像,所以这个问题得到了缓解,我们改为使用 range(4) 来索引每个单独的图像。

2:是的,我们确实将 x, y 位置转换为张量。

3: 将它们分离成张量后,我们直接用它们进行索引。