在pytorch中输入3和1通道输入到网络?

时间:2018-05-22 15:10:04

标签: pytorch

我的数据集主要包含3个频道图像,但我也有一些1频道图像,是否可以训练一个同时包含3个频道和1个频道作为输入的网络?

欢迎任何建议,提前致谢,

2 个答案:

答案 0 :(得分:1)

您可以通过检查尺寸来检测灰度图像,并应用一些转换以获得3个通道。

将图像从灰度转换为RGB似乎比仅仅在通道上复制图像三次更好。

如果您安装了cv2.cvtColor(gray_img, cv.CV_GRAY2RGB),则可以opencv-python执行此操作。

如果您想要一个干净的实施,您可以使用新的torchvision.transform扩展Transform,自动完成此工作。

答案 1 :(得分:0)

加载图像并将其转换为RGB:

image = Image.open(path).convert('RGB')

您可以将此行添加到数据集的__getitem__方法中。它将灰度图像转换为RGB,而仅将彩色图像保持为RGB。