在pytorch中使用4通道图像进行分类

时间:2020-01-07 04:55:33

标签: classification pytorch

我有一些带有标签的灰度和彩色图像。我想将这种灰色和彩色图像(4通道)结合起来,并使用4通道图像运行转移学习。该怎么做?

2 个答案:

答案 0 :(得分:0)

您当前的模型期望RGB输入只有三个通道,因此其第一转换层具有in_channels=3,并且该第一层的weight的形状为out_channels x 3 < / strong> x kernel_height x kernel_width
为了容纳4通道输入,您需要将第一层更改为in_channels=4和形状为weight x 4 x {{1} } x out_channels。您还希望保留学习的权重,因此您应该将新的kernel_height初始化为与旧的权重相同,除了增加的权重中的微小噪音之外。

答案 1 :(得分:0)

如果我正确理解了这个问题,则希望将1个通道的图像和3个通道的图像组合在一起,并获得4个通道的图像并将其用作输入。

如果这是您要执行的操作,则可以使用torch.cat()。

一些示例代码,用于加载两个图像并沿通道维度进行组合

import numpy as np
import torch
from PIL import Image

image_rgb = Image.open(path_to_rgb_image)
image_rgb_tensor = torch.from_numpy(np.array(image_rgb))
image_rgb.close()

image_grayscale = Image.open(path_to_grayscale_image))
image_grayscale_tensor = troch.from_numpy(np.array(image_grayscale))
image_grayscale.close()

image_input = torch.cat([image_rgb_tensor, image_grayscale_tensor], dim=2)

我假设您要使用的灰度图像转换为形状为[..., ..., 1]的张量,而rgb图像转换为[..., ..., 3]