使用tf.nn.sigmoid()后如何获得多标签输出

时间:2017-09-07 07:31:14

标签: python tensorflow tensorflow-serving

目的:

我希望在预测时直接在tensorflow服务中获取标签名称,我的问题是如何将pred = tf.nn.sigmoid(last_layer_output)转移到真实标签名称?

问题描述:

我知道如何处理多类问题:

CLASSES = tf.constant(['a', 'b', 'c', 'd', 'e'])

pred = tf.nn.softmax(last_layer_output)

# pred pretty similar to:
pred = [[0, 1, 0, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 0, 1]]

classes_value = tf.argmax(pred, 1)
classes_name = tf.gather(CLASSES, classes_value)
# classes_name: [b'b' b'd' b'e']
# batch_size = 3

所以classes_name就是我想要的,我可以在tensorflow-serving中使用它设计签名,当我预测时,我可以获得最终标签。

但是在多标签问题上如何做到这一点?

e.g。

CLASSES = tf.constant(['a', 'b', 'c', 'd', 'e'])

pred = tf.nn.sigmoid(last_layer_output)
pred = tf.round(pred)

# pred pretty similar to:
pred = [[0, 1, 1, 0, 1],     # 'b','c','e'
        [1, 0, 0, 1, 0],     # 'a','d'
        [1, 1, 1, 0, 1]]     # 'a','b','c','e'

如何将pred转移到标签名称?我无法在sees.run()之后或使用其他类似numpy的API执行此操作,因为这是针对tensorflow服务的,我认为我需要使用tf方法来执行此操作。

任何建议都表示赞赏!

1 个答案:

答案 0 :(得分:0)

您应该首先定义许多类的给定概率,您要返回哪些类。例如,如果此类的概率高于0.5。

然后你可以使用tf.where和适当的条件来获取索引,然后使用相同的tf.gather来获得类。

像这样:

indicies = tf.where(tf.greater(pred, 0.5))
classes = tf.gather(CLASSES, indicies[:, 1])

然后你需要使用indicies[:, 0]重新组织它,告诉你哪个类的例子来自。

另外,你必须明白答案的正确形式是SparseTensor,它不受服务等支持。所以你可能想要返回两个张量[字符串和指示符,哪个字符串是哪个例子]并处理你的客户方。