我正在尝试按照本教程创建TFRecord格式的训练数据集:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/using_your_own_dataset.md用于API检测。
但是,我想使用k-hot编码,而不是使用一个热编码。例如,我没有[0 0 0 1 0]标签,而是可以[0 1 0 1 0]进行多分类。我想知道如何使用TFRecord格式。如果我使用2-hot编码,是否必须创建两个tf.train.example
? (使用两次相同的bouding box坐标)还是有另一种方式? (例如使用'image/object/class/text': dataset_util.bytes_list_feature(classes_text)
和'image/object/class/text2': dataset_util.bytes_list_feature(classes_text2))
?
答案 0 :(得分:0)
鉴于您有[0、1、2]和10个类的标签列表,因此需要
def int64_feature(value):
if type(value) != list:
value = [value]
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
然后将标签传递给tf.Example作为功能之一
'label': int64_feature(label)
在那之后,当您在训练期间解析数据集时,您可以这样删除标签:
tf.reduce_max(tf.one_hot(labels, num_classes, dtype=tf.int32), axis=0)
哪个给
[1 1 1 0 0 0 0 0 0 0]