我正在处理3D CT医学数据,并且试图将其切成可输入UNet模型的2D切片。
我已将数据加载到割炬数据加载器中,并且每次迭代当前都会生成4D张量:
for batch_index, batch_samples in enumerate(train_loader):
data, target = batch_samples['image'].float().cuda(), batch_samples['label'].float().cuda()
print(data.size())
torch.Size([1, 333, 512, 512])
torch.Size([1, 356, 512, 512])
例如这个。我想遍历333个切片,然后遍历356个切片,以使模型每次都接收到割炬尺寸[1、1、512、512]。
我希望是这样的:
for x in (data[:,x,:,:]):
可以工作,但是它说我需要先定义x。如何遍历火炬张量中的特定尺寸?
答案 0 :(得分:1)
只需指定尺寸:
for i in range(data.shape[1]): # dim=1
x = data[:, i, :, :]
# [...]
如果您确实需要额外的尺寸,只需添加.unsqueeze()
:
d = 1
for i in range(data.shape[d]): # dim=1
x = data[:, i, :, :].unsqueeze(d) # same dim=1
# [...]