将张量中的每个值映射到新值,具体取决于其在另一个张量中的索引

时间:2019-10-31 14:28:30

标签: python tensorflow2.0

我正在使用Tensorflow 2.0。我有一个(256 x 256)张量,范围在0到255之间,我们称之为gray。每个值都是10个唯一值之一。我有另一个张量uniqueValues,其中包含10个唯一值。我正在尝试找到一种创建新(256 x 256)张量result的方法,其中result的第i,j个值等于uniqueValues的索引,其中gray的第i,j个值出现:

  gray = tf.image.decode_png(png, channels=1)
  flattened = tf.reshape(gray, [-1])

  # creates a tensor of length 10 holding each unique value
  uniqueValues, idx = tf.unique(flattened)
  gray = tf.reshape(gray, (256, 256))

  # Convert the gray (256x256) tensor...
  # [[255 255 255 ... 255
  # ...
  #  255 15 15 ... 200]]

  # using 'uniqueValues'...
  # [ 15 200 255 ]

  # To result (256x256) tensor...
  # [[2 2 2 ... 2
  # ...
  #  2 0 0 ... 1 ]]

  # possibly using the tf.map_fn?
  result = tf.map_fn( # how to do this part?, gray)

  # now I can create the one-hot version of gray
  oneHot = tf.one_hot(result, 10)

曾经玩过tf.wheretf.equal,但我似乎无法使其正常工作。

1 个答案:

答案 0 :(得分:0)

以防万一其他人为此而苦恼,这是基于使用StaticHashTable的解决方案:

import tensorflow as tf

# define mapping from keys to values...
lookupTable = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant([0, 76, 78, 117, 178, 202, 211, 225, 242, 255]),
        values=tf.constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    ),
    default_value=tf.constant(0)
)

  gray = tf.image.decode_png(png, channels=1)

  # cast source from uint8 to int32 because StaticHashMap only works 
  # with restricted set of types
  gray = tf.dtypes.cast(tf.reshape(gray, (256, 256)), tf.int32)

  # voila, works like a charm!
  result = tf.map_fn(lambda x: lookupTable.lookup(x), gray)
  oneHot = tf.one_hot(result, 10)