我是计算机视觉和深度学习的新手。我正在尝试使用 Resnet50 作为编码器来训练这个 Unet 模型 https://github.com/kevinlu1211/pytorch-unet-resnet-50-encoder。我想以一种方式实现它,我传递两个 rgb 图像,这些图像首先由 resnet50 处理,然后在传递给解码器之前连接层。我尝试这样做并将代码中的 n_classes 更改为 3 以输出 3 通道 rgb 图像,就像输入一样,但它给了我一个扭曲的图像 like this,我不明白为什么。请帮我解决这个问题。
我修改的代码中通过 resnet50 处理两个 rgb 输入的部分在这里 -
for i, block in enumerate(self.down_blocks, 2): # for all the down blocks
x = block(x)
if i == (UNetWithResnet50Encoder.DEPTH - 1):
continue
pre_pools[f"layer_{i}"] = x ## creating all the down sampling layers
pre_pools_inp2 = dict()
pre_pools_inp2[f"layer_0"] = y
y = self.input_block(y) #
pre_pools_inp2[f"layer_1"] = y
y = self.input_pool(y)
for i, block in enumerate(self.down_blocks, 2): # for all the down blocks
y = block(y)
if i == (UNetWithResnet50Encoder.DEPTH - 1):
continue
pre_pools_inp2[f"layer_{i}"] = y ## creating all the down sampling layers
x = torch.cat([x,y],1)
x = self.bridge(x) # this is now the bridge between down sampling and up sampling
for i, block in enumerate(self.up_blocks, 1):
key = f"layer_{UNetWithResnet50Encoder.DEPTH - 1 - i}" # now using that bridge for upsampling f
x = block(x, pre_pools[key])
output_feature_map = x
x = self.out(x)
del pre_pools
if with_output_feature_map:
return x, output_feature_map
else:
return x