如何在PyTorch中将RGB图像编码为n_class个热张量

时间:2019-07-19 10:04:57

标签: deep-learning computer-vision pytorch torch torchvision

因此,我正在执行分割任务,我需要做的是将RGB图像转换为n通道一个热矩阵,用作U-Net模型的标签

我所做的是计算图像中的颜色。颜色或颜色等于类的数量。

我尝试在PerPixelClassMatrix函数中进行的操作是遍历Image,并制作一个n维尺寸为0和1s的矩阵,因为我具有每个像素的颜色和类。

import glob
from tqdm import tqdm
import numpy as np


class HotEncoder():
    def __init__(self, dir, extension, is_binary=True):
        self.dir = dir
        self.extension = extension
        self.is_binary = is_binary
        if is_binary:
            self.color = {(0, 0, 0): 1, (255, 255, 255): 2}
        else:
            self.color = dict()

    def gen_colors(self):
        """Iterates through the entire dataset and finds the total colours
            in the images so that they can be used to one hot the image matrix
            for the training data"""
        if self.is_binary:
            return self.color
        else:
            n_color=1
            images = glob.glob(self.dir + '/*.' + self.extension)
            for img in tqdm(images, desc="Generating Color Pallte to Hot Encode"):
                image = skimage.io.imread(img)
                shape_ = image.shape
                for x in range(shape_[0]):
                    for y in range(shape_[1]):
                        clr= tuple(image[x][y][:])
                        if clr not in self.color.keys():
                            self.color.update({n_color: clr})
                            n_color+=1
                        else:
                            pass
        return self.color

    def PerPixelClassMatrix(self, Image):
        """Takes an Image and returns a per pixel class
            identification map"""
        class_list= []
        class_mat= np.array([])
        shape_= Image.shape
        for x in range(shape_[0]):
            for y in range(shape_[1]):
                clr= tuple(Image[x][y][:])
                if clr in self.color.keys():
                    class_list.append(self.color[clr])
                else:
                    pass
        return class_list

我不想运行整个循环来生成一个n通道的热图像。有没有一种简单的方法来构造这种已知颜色的矩阵。

1 个答案:

答案 0 :(得分:0)

如果要计算图像分割损失,可以执行以下操作:

output = model(input)  # raw logit output in shape [1, 3, 512, 512]
loss = criterion(F.log_softmax(output,1), target)  # target in shape [1, 512, 512]

目标将包含带有掩码索引的标签[0, N)。我假设您输入的图像是3通道RGB。

Source of the answer,可以在附近找到示例。