Tensorflow:使用slim.dataset.Dataset时,有没有办法将标签ID值映射到其他值?

时间:2018-04-26 22:55:45

标签: tensorflow tensorflow-slim

dataset = slim.dataset.Dataset(...)
provider = slim.dataset_data_provider.DatasetDataProvider(dataset, ..._
image, labels = provider.get(['image', 'label')

让我们说,对于数据集A中的示例,labels可以是[1, 2, 1, 3]。但是,由于某种原因(例如,由于数据集B),我想将标签ID映射到其他值。映射可能如下所示。

# {old_label: target_label}
mapping = {0: 0, 1: 2, 2: 2, 3: 2, 4: 2, 5: 3, 6: 1}

目前,我猜两种方式:

- tf.data.Dataset似乎有一个map(map_func)函数,每个示例都应该通过,这可能是解决方案。但是,我对slim.dataset.Dataset更为熟悉。 slim.dataset.Dataset是否有类似的技巧?

- 我想知道我是否可以简单地将一些映射函数应用于张量label,例如:

new_labels = tf.map_fn(lambda x: x+1, labels, dtype=tf.int32)
# labels = [1 2 1 3] --> new_labels = [2 3 2 4]. This works.

new_labels = tf.map_fn(lambda x: mapping[x], labels, dtype=tf.int32)
# I wished but this does not work!

然而,下面没有工作,这就是我需要的。有人可以建议吗?

1 个答案:

答案 0 :(得分:0)

我认为你可以试试tf.contrib.lookup

keys = list(mapping.keys())
values = [mapping[k] for k in keys]
table = tf.contrib.lookup.HashTable(
  tf.contrib.lookup.KeyValueTensorInitializer(keys, values, key_dtype=tf.int64, value_dtype=tf.int64), -1
)
new_labels = table.lookup(labels)
sess=tf.Session()
sess.run(table.init)
print(sess.run(new_labels))