在tensorflow中将标签映射到RGB

时间:2018-10-01 17:35:24

标签: tensorflow

我们如何在tensorflow中将整数标签映射到不同的伪颜色。我尝试过:

with sess.as_default():
a = tf.random_uniform(shape=[3,3,3],minval=0,maxval=5,dtype=tf.int32)
keys = [0,1,2,3,4,5]
values = [0,10,20,30,40,50]
table = tf.contrib.lookup.HashTable(tf.contrib.lookup.KeyValueTensorInitializer(keys, values), -1)
out = table.lookup(a)
table.init.run()
print(out.eval())

但是此选项无法解决我的用例。对于输入张量中的每个唯一标签[height x width xchannels = 1],我想将其映射到[height x width xchannels = 3]

labels                LUT
-------               ----
   0                 [0,0,0]
   1                 [128,64,128]
   2                 [64,128,256]
   3                 [255,64,128]

感谢一些帮助。

2 个答案:

答案 0 :(得分:1)

您可以使用tf.nn.embedding_lookuplabels映射到LUT表:

LUT = tf.constant([[0,0,0],[128,64,128],[64,128,256],[255,64,128]], tf.int32)
labels = [0, 1, 2, 3]
out = tf.nn.embedding_lookup(table, labels)

with tf.Session() as sess:
   print(sess.run(out))

#[[  0   0   0]
# [128  64 128]
# [ 64 128 256]
# [255  64 128]]

答案 1 :(得分:0)

使用Google Colab的完整示例(底部的链接)

import tensorflow as tf
from google.colab import drive
from matplotlib import pyplot as plt
import numpy as np
import cv2

drive.mount('/content/drive')

sess = tf.Session()

# an input image

img = cv2.imread('/content/drive/My Drive/Colab Notebooks/test.png')
img = img[:, :, 0]
print(img.shape)

label = tf.placeholder(tf.int32,shape=[None,None])

LUTval = np.zeros((256,3)) # lookup table must span all possible labels, classes are 0 to 19, 255 is unlabelled in this example
LUTval[0:19, :] = [[128, 64, 128],[244, 35, 232],[70, 70, 70],[102, 102, 156],[190, 153, 153],[153, 153, 153],[250, 170, 30],[220, 220, 0],[107, 142, 35],[152, 251, 152],[70, 130, 180],[220, 20, 60],[255, 0, 0],[0, 0, 142],[0, 0, 70],[0, 60, 100],[0, 80, 100],[0, 0, 230],[119, 11, 32]]
LUT = tf.constant(LUTval, tf.int32)
col = tf.nn.embedding_lookup(LUT, label)

res_col = sess.run(col, feed_dict = {label: img})

print(res_col.shape)

#img2 = img[:,:,::-1]
plt.imshow(cv2.cvtColor(np.array(res_col).astype(np.uint8), cv2.COLOR_BGR2RGB))

Snippet on Google Colab Test Image