多个标签的张量流预测

时间:2016-04-11 14:51:15

标签: tensorflow

如何在张量流图中获得预测向量 predict = tf.argmax(y)? (因为argmax仅适用于softmax分类器)

我有一个多标签分类问题因此我需要类似的东西:

predictions = [1. if prob > 0.5 else 0. for prob in y]

1 个答案:

答案 0 :(得分:4)

希望这会有所帮助:

import tensorflow as tf
import numpy as np
sess = tf.InteractiveSession()
prob = tf.constant(np.random.rand(10))
predictions = tf.select(prob > 0.5, tf.ones_like(prob), tf.zeros_like(prob))
print(predictions.eval())