Tensorflow:如果灰度将图像转换为RGB

时间:2019-02-26 11:17:16

标签: python tensorflow

我有一个rgb和灰度图像的数据集。遍历数据集时,我想检测图像是否为灰度图像,以便可以将其转换为rgb。我想使用tf.shape(image)来检测图像的尺寸。对于rgb图像,我得到类似[1, 100, 100, 3]的图像。对于灰度图像,该函数返回例如[1, 100, 100]。我想使用len(tf.shape(image))来检测它的长度是4(= rgb)还是长度3(=灰度)。那没用。

这是到目前为止我的无效代码:

def process_image(image):
    # Convert numpy array to tensor
    image = tf.convert_to_tensor(image, dtype=tf.uint8)
    # Take care of grayscale images
    dims = len(tf.shape(image))
    if dims == 3:
        image = np.expand_dims(image, axis=3)
        image = tf.image.grayscale_to_rgb(image)
    return image

是否存在将灰度图像转换为rgb的替代方法?

2 个答案:

答案 0 :(得分:2)

我遇到了一个非常相似的问题,我想一次性加载 rgb 和灰度图像。 Tensorflow 支持在读取图像时设置通道号。因此,如果图像具有不同数量的通道,这可能就是您要查找的内容:

# to get greyscale:
tf.io.decode_image(raw_img, expand_animations = False, dtype=tf.float32, channels=1)

# to get rgb:
tf.io.decode_image(raw_img, expand_animations = False, dtype=tf.float32, channels=3)

-> 您甚至可以在同一图像和内部 tf.data.Dataset 映射上进行!

您现在必须设置 channels 变量以匹配您需要的形状,因此所有加载的图像都将具有该形状。比你可以无条件地重塑。

这也允许你在 Tensorflow 中直接将灰度图像加载到 RGB。举个例子:

    >> a = Image.open(r"Path/to/rgb_img.JPG")
    >> np.array(a).shape
    (666, 1050, 3)
    >> a = a.convert('L')
    >> np.array(a).shape
    (666, 1050)
    >> b = np.array(a)
    >> im = Image.fromarray(b) 
    >> im.save(r"Path/to/now_it_is_greyscale.jpg")
    >> raw_img = tf.io.read_file(r"Path/to/now_it_is_greyscale.jpg")
    >> img = tf.io.decode_image(raw_img, dtype=tf.float32, channels=3)
    >> img.shape
    TensorShape([666, 1050, 3])
    >> img = tf.io.decode_image(raw_img, dtype=tf.float32, channels=1)
    >> img.shape
    TensorShape([666, 1050, 1])

如果得到 expand_animations = False,请使用 ValueError: 'images' contains no shape.!请参阅:https://stackoverflow.com/a/59944421/9621080

答案 1 :(得分:1)

您可以为此使用以下功能:

import tensorflow as tf

def process_image(image):
    image = tf.convert_to_tensor(image, dtype=tf.uint8)
    image_rgb =  tf.cond(tf.rank(image) < 4,
                         lambda: tf.image.grayscale_to_rgb(tf.expand_dims(image, -1)),
                         lambda: tf.identity(image))
    # Add shape information
    s = image.shape
    image_rgb.set_shape(s)
    if s.ndims is not None and s.ndims < 4:
        image_rgb.set_shape(s.concatenate(3))
    return image_rgb