tensorflow TFRecord k-hot编码

时间:2018-05-20 12:15:24

标签: python tensorflow tfrecord

我正在尝试按照本教程创建TFRecord格式的训练数据集:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/using_your_own_dataset.md用于API检测。

但是,我想使用k-hot编码,而不是使用一个热编码。例如,我没有[0 0 0 1 0]标签,而是可以[0 1 0 1 0]进行多分类。我想知道如何使用TFRecord格式。如果我使用2-hot编码,是否必须创建两个tf.train.example? (使用两次相同的bouding box坐标)还是有另一种方式? (例如使用'image/object/class/text': dataset_util.bytes_list_feature(classes_text)'image/object/class/text2': dataset_util.bytes_list_feature(classes_text2))

1 个答案:

答案 0 :(得分:0)

鉴于您有[0、1、2]和10个类的标签列表,因此需要

def int64_feature(value):
    if type(value) != list:
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

然后将标签传递给tf.Example作为功能之一

'label': int64_feature(label)

在那之后,当您在训练期间解析数据集时,您可以这样删除标签:

tf.reduce_max(tf.one_hot(labels, num_classes, dtype=tf.int32), axis=0)

哪个给

[1 1 1 0 0 0 0 0 0 0]