比如说,我在 PyTorch 中有一批图像。对于每个图像,我还有一个像素位置,比如 (x, y)
。可以使用 img[x, y]
读取一张图像的像素值。我正在尝试读取批处理中每个图像的像素值。请参阅下面的代码片段:
import torch
# create tensors to represent random images in torch format
img_1 = torch.rand(1, 200, 300)
img_2 = torch.rand(1, 200, 300)
img_3 = torch.rand(1, 200, 300)
img_4 = torch.rand(1, 200, 300)
# for each image, x-y value are know, so creating a tuple
img1_xy = (0, 10, 70)
img2_xy = (0, 40, 20)
img3_xy = (0, 30, 50)
img4_xy = (0, 80, 60)
# this is what I am doing right now
imgs = [img_1, img_2, img_3, img_4]
imgs_xy = [img1_xy, img2_xy, img3_xy, img4_xy]
x = [img[xy] for img, xy in zip(imgs, imgs_xy)]
x = torch.as_tensor(x)
(x, y)
是已知的。但是,我必须创建一个包含另外一个元素的元组,即 0 以确保该元组与图像的形状相匹配。有什么优雅的方式吗?tuple
,我们不能使用张量然后获取像素值吗?img_batch = torch.cat((img_1, img_2, img_3, img_4))
。但是元组呢?答案 0 :(得分:1)
您可以连接图像以形成 (4, 200, 300)
形堆叠张量。然后,我们可以使用每个图像的已知 (x, y)
对索引到它,如下所示:我们需要 [0, x1, y1]
用于第一幅图像,[1, x2, y2]
用于第二幅图像,[2, x3, y3]
用于第三幅图像等等。这些可以通过“花式索引”来实现:
# stacking as you did
>>> stacked_imgs = torch.cat(imgs)
>>> stacked_imgs.shape
(4, 200, 300)
# no need for 0s in front
>>> imgs_xy = [(10, 70), (40, 20), (30, 50), (80, 60)]
# need xs together and ys together: take transpose of `imgs_xy`
>>> inds_x, inds_y = torch.tensor(imgs_xy).T
>>> inds_x
tensor([10, 40, 30, 80])
>>> inds_y
tensor([70, 20, 50, 60])
# now we index into the batch
>>> num_imgs = len(imgs)
>>> result = stacked_imgs[range(num_imgs), inds_x, inds_y]
>>> result
tensor([0.5359, 0.4863, 0.6942, 0.6071])
我们可以检查结果:
>>> torch.tensor([img[0, x, y] for img, (x, y) in zip(imgs, imgs_xy)])
tensor([0.5359, 0.4863, 0.6942, 0.6071])
回答您的问题:
1:由于我们堆叠了图像,所以这个问题得到了缓解,我们改为使用 range(4)
来索引每个单独的图像。
2:是的,我们确实将 x, y
位置转换为张量。
3: 将它们分离成张量后,我们直接用它们进行索引。