Pytorch:批处理中每个图像的特定于文件的操作

时间:2021-05-18 17:55:17

标签: python pytorch pytorch-dataloader

我有一个图像数据集,每个图像都有一个附加属性“channel_no”。每个图像都应该根据它的channel_no用nn层进行处理:

 images with channel_no=1 have to be processed with layer1
 images with channel_no=2 have to be processed with layer2
 images with channel_no=3 have to be processed with layer3
etc...

问题是当batch包含多张图片时,forward()函数会得到一个以该批图片为输入的torch张量,每张图片都有不同的channel_no。所以不清楚如何分别处理每张图片。

这是批处理只有 1 张图像的情况的代码:

class Net(nn.Module):
    def __init__ (self, weight):
        super(Net, self).__init__()

        self.layer1 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.layer2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.layer3 = nn.Linear(hidden_sizes[0], hidden_sizes[1])

        self.outp = nn.Linear(hidden_sizes[1], output_size)
        
    def forward(self, x, channel_no):
        channel_no = channel_no[0] #extract channel_no from the batch list

        x = x.view(-1,hidden_sizes[0])

        if channel_no == 1: x = F.relu(self.layer1(x))
        if channel_no == 2: x = F.relu(self.layer2(x))
        if channel_no == 3: x = F.relu(self.layer3(x))

        x = torch.sigmoid(self.outp(x))

        return x    

是否可以使用批量大小 > 1 单独处理每个图像?

1 个答案:

答案 0 :(得分:1)

要单独处理图像,您可能需要单独的张量。我不确定是否有一种快速的方法可以做到这一点,但是您可以在批处理维度中拆分张量以获得单个图像张量,然后遍历它们以按通道号对它们进行排序。然后将每组具有相同通道号的图像加入一个新的张量并对该张量进行特殊处理。

相关问题