Features lost after one-hot encoding a semantic segmentation mask

Posted: 2021-03-16 02:45:52

Tags: python tensorflow image-processing image-segmentation

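I'm working on a simple face-segmentation project with U-Net. I have 11 classes, and my input labels are .png RGB segmentation masks, so I'm trying to one-hot encode them (i.e. convert each RGB mask into 11 binary masks, one per class).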

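However, when I one-hot encode the labels and then decode them back to RGB, my code seems to drop features from the mask, and I can't figure out why. It would be great if someone could look over the code or, even better, suggest a more efficient way to do the RGB-to-one-hot conversion (and back) :)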
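The conversion described above can also be done without a Python loop over the classes. The following is a minimal NumPy sketch, assuming class i is stored in row i of a palette array built from the color_dict in the question; PALETTE, rgb_to_onehot_vec and onehot_to_rgb_vec are hypothetical names, not from the original post. Like the loop version, it only matches exact colours, so it is faster but does not by itself recover missing features.

```python
import numpy as np

# Hypothetical palette: row i is the RGB colour of class i
# (same order as color_dict in the question).
PALETTE = np.array([
    (0, 0, 0), (0, 255, 247), (0, 128, 0), (128, 128, 0),
    (0, 0, 128), (128, 0, 128), (0, 68, 255), (107, 107, 255),
    (0, 255, 85), (192, 0, 0), (0, 0, 255),
], dtype=np.uint8)


def rgb_to_onehot_vec(rgb, palette=PALETTE):
    """(H, W, 3) uint8 mask -> (H, W, C) uint8 one-hot, exact colour match only."""
    # Broadcast (H, W, 1, 3) against (1, 1, C, 3), then require all 3 channels to match.
    matches = np.all(rgb[:, :, None, :] == palette[None, None, :, :], axis=-1)
    return matches.astype(np.uint8)


def onehot_to_rgb_vec(onehot, palette=PALETTE):
    """(H, W, C) one-hot -> (H, W, 3) uint8 RGB via a palette lookup."""
    # Index the palette with the per-pixel class id.
    return palette[np.argmax(onehot, axis=-1)]
```

For a 512x512 mask the broadcast comparison holds roughly 8.7 million booleans, which is cheap, and decoding is a single fancy-indexing lookup instead of one masked assignment per class.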

[Image: original mask vs. mask after one-hot encoding and decoding back]

Here is the code I used to reproduce the problem:

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

# Test on one sample
mask_path = "data/HELEN_warped_validation/colourized_labels/10405146_1.png"

mask = tf.io.read_file(mask_path)
mask = tf.io.decode_png(mask)

color_dict = {
    0: (0, 0, 0),  # 0=background
    1: (0, 255, 247),  # 1=face skin
    2: (0, 128, 0),  # 2=left eyebrow
    3: (128, 128, 0),  # 3=right eyebrow
    4: (0, 0, 128),  # 4=left eye
    5: (128, 0, 128),  # 5=right eye
    6: (0, 68, 255),  # 6=nose
    7: (107, 107, 255),  # 7=upper lip
    8: (0, 255, 85),  # 8=inner mouth
    9: (192, 0, 0),  # 9=lower lip
    10: (0, 0, 255)  # 10=hair
}


def rgb_to_onehot(rgb_arr, color_dict):
    num_classes = len(color_dict)
    shape = rgb_arr.shape[:2] + (num_classes,)
    arr = np.zeros(shape, dtype=np.int8)
    for i, color in color_dict.items():
        # Channel i is 1 only where the pixel equals color_dict[i] exactly;
        # a pixel that matches no entry stays 0 in every channel.
        arr[:, :, i] = np.all(rgb_arr.reshape((-1, 3)) == color, axis=1).reshape(shape[:2])
    return arr


def onehot_to_rgb(onehot, color_dict):
    # argmax of an all-zero vector is 0, so unmatched pixels decode as background
    single_layer = np.argmax(onehot, axis=-1)
    output = np.zeros(onehot.shape[:2] + (3,))
    for k in color_dict.keys():
        output[single_layer == k] = color_dict[k]
    return np.uint8(output)


print(mask.shape)         # 512, 512, 3
onehot_mask = rgb_to_onehot(mask.numpy(), color_dict)
print(onehot_mask.shape)  # 512, 512, 11
rgb_mask = onehot_to_rgb(onehot_mask, color_dict)
print(rgb_mask.shape)     # 512, 512, 3

# Visualise encoded mask
plt.subplot(1, 2, 1)
plt.title('Input Image')
plt.imshow(mask)

plt.subplot(1, 2, 2)
plt.title('RGB Mask')
plt.imshow(rgb_mask)

plt.show()
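One plausible explanation (an assumption, not confirmed in the post): the PNG contains pixels whose colours are close to, but not exactly equal to, an entry in color_dict, for example from interpolation when the labels were warped or resized. Such pixels get an all-zero one-hot vector in rgb_to_onehot, and np.argmax in onehot_to_rgb then maps them to class 0 (background). The sketch below counts those pixels and, as one possible remedy, swaps exact matching for nearest-colour assignment by squared Euclidean distance in RGB. PALETTE, count_unmatched and rgb_to_onehot_nearest are hypothetical names.

```python
import numpy as np

# Hypothetical palette, same class order as color_dict in the question.
# int32 so that subtraction below cannot wrap around like uint8 would.
PALETTE = np.array([
    (0, 0, 0), (0, 255, 247), (0, 128, 0), (128, 128, 0),
    (0, 0, 128), (128, 0, 128), (0, 68, 255), (107, 107, 255),
    (0, 255, 85), (192, 0, 0), (0, 0, 255),
], dtype=np.int32)


def count_unmatched(rgb, palette=PALETTE):
    """Number of pixels whose RGB value equals no palette colour exactly."""
    matches = np.all(rgb[:, :, None, :].astype(np.int32) == palette[None, None], axis=-1)
    return int(np.count_nonzero(~matches.any(axis=-1)))


def rgb_to_onehot_nearest(rgb, palette=PALETTE):
    """Assign each pixel to the palette colour with the smallest squared RGB distance."""
    diff = rgb[:, :, None, :].astype(np.int32) - palette[None, None]  # (H, W, C, 3)
    labels = np.argmin(np.sum(diff * diff, axis=-1), axis=-1)         # (H, W)
    return np.eye(len(palette), dtype=np.uint8)[labels]               # (H, W, C)
```

With the question's variables, count_unmatched(mask.numpy()) returning a non-zero value would support this explanation, and np.unique(mask.numpy().reshape(-1, 3), axis=0) lists the colours actually present. Nearest-colour assignment can still mislabel boundary pixels that are blends of two classes; regenerating the masks with nearest-neighbour interpolation avoids creating such colours in the first place.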

0 answers