如何将3D火炬张量切成2D切片

时间:2020-07-18 12:26:56

标签: iteration pytorch torch medical dataloader

我正在处理3D CT医学数据,并且试图将其切成可输入UNet模型的2D切片。

我已将数据加载到割炬数据加载器中,并且每次迭代当前都会生成4D张量:

for batch_index, batch_samples in enumerate(train_loader):
    data, target = batch_samples['image'].float().cuda(), batch_samples['label'].float().cuda()
    print(data.size())
torch.Size([1, 333, 512, 512])
torch.Size([1, 356, 512, 512])

例如这个。我想遍历333个切片,然后遍历356个切片,以使模型每次都接收到割炬尺寸[1、1、512、512]。

我希望是这样的:

for x in (data[:,x,:,:]):

可以工作,但是它说我需要先定义x。如何遍历火炬张量中的特定尺寸?

1 个答案:

答案 0 :(得分:1)

只需指定尺寸:

for i in range(data.shape[1]):  # dim=1
    x = data[:, i, :, :]
    # [...]

如果您确实需要额外的尺寸,只需添加.unsqueeze()

d = 1
for i in range(data.shape[d]):         # dim=1
    x = data[:, i, :, :].unsqueeze(d)  # same dim=1
    # [...]