dataset = slim.dataset.Dataset(...)
provider = slim.dataset_data_provider.DatasetDataProvider(dataset, ..._
image, labels = provider.get(['image', 'label')
让我们说,对于数据集A中的示例,labels
可以是[1, 2, 1, 3]
。但是,由于某种原因(例如,由于数据集B),我想将标签ID映射到其他值。映射可能如下所示。
# {old_label: target_label}
mapping = {0: 0, 1: 2, 2: 2, 3: 2, 4: 2, 5: 3, 6: 1}
目前,我猜两种方式:
- tf.data.Dataset
似乎有一个map(map_func)
函数,每个示例都应该通过,这可能是解决方案。但是,我对slim.dataset.Dataset
更为熟悉。 slim.dataset.Dataset
是否有类似的技巧?
- 我想知道我是否可以简单地将一些映射函数应用于张量label
,例如:
new_labels = tf.map_fn(lambda x: x+1, labels, dtype=tf.int32)
# labels = [1 2 1 3] --> new_labels = [2 3 2 4]. This works.
new_labels = tf.map_fn(lambda x: mapping[x], labels, dtype=tf.int32)
# I wished but this does not work!
然而,下面没有工作,这就是我需要的。有人可以建议吗?
答案 0 :(得分:0)
我认为你可以试试tf.contrib.lookup:
keys = list(mapping.keys())
values = [mapping[k] for k in keys]
table = tf.contrib.lookup.HashTable(
tf.contrib.lookup.KeyValueTensorInitializer(keys, values, key_dtype=tf.int64, value_dtype=tf.int64), -1
)
new_labels = table.lookup(labels)
sess=tf.Session()
sess.run(table.init)
print(sess.run(new_labels))