多热标签编码

时间:2019-02-21 21:07:22

标签: python tensorflow

我是Tensorflow的新手。我有一个图像数据集,其中一个图像带有多个标签。据我了解,我需要使用tf.losses.sigmoid_cross_entropy()。我尝试将tf.one_hot应用于标签,但是当我尝试将它们传递给损失函数时,出现错误,形状不兼容。我该如何解决?

1 个答案:

答案 0 :(得分:0)

您对tf.losses.sigmoid_cross_entropy的看法是正确的。您需要做的就是用tf.one_hot包装tf.reduce_max以减少尺寸。

tf.reduce_max(tf.one_hot(labels, num_classes, dtype=tf.int32), axis=0)

那应该返回形状(num_classes,)的张量,正是您的损失函数所需要的。