TensorFlow: multi-label accuracy based on the top-k predicted values

Time: 2017-08-18 01:37:11

Tags: tensorflow

My setup: I am given a set of data labels, and some examples can carry multiple labels (the maximum number of 1s per example is denoted by n):

1,0,0,0,1,0
0,0,1,0,1,1
....
1,1,1,0,0,0

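At prediction time, I want to look at the indices of the top n logits. If those predicted indices contain every index where the label is 1, the prediction counts as correct. How can I implement this in TensorFlow?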
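To pin down the criterion before turning to TensorFlow, here is a minimal plain-Python sketch of it. The function name `topk_contains_labels` and the logit values are illustrative assumptions, not part of the original post; only the two label rows are taken from the example above.

```python
def topk_contains_labels(labels, logits, k):
    """Return one bool per example: True if every positive label is in the top-k logits."""
    results = []
    for label_row, logit_row in zip(labels, logits):
        # Rank class indices by descending logit; sorted() is stable,
        # so tied logits keep the lower index first.
        ranked = sorted(range(len(logit_row)), key=lambda j: -logit_row[j])
        top_k = set(ranked[:k])
        # Indices where the label is 1.
        positives = {j for j, y in enumerate(label_row) if y == 1}
        # Correct only if all positive indices are covered by the top-k set.
        results.append(positives <= top_k)
    return results


# Label rows from the question; the logits are made up for illustration.
labels = [[1, 0, 0, 0, 1, 0],
          [0, 0, 1, 0, 1, 1]]
logits = [[0.9, 0.1, 0.2, 0.1, 0.8, 0.3],
          [0.1, 0.7, 0.6, 0.2, 0.3, 0.9]]
result = topk_contains_labels(labels, logits, k=3)
print(result)
```

The first example counts as correct because both of its positive indices are among its three largest logits. The second does not, because index 4 falls outside its top 3.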

1 Answer:

Answer 0: (score: 0)

You can do it like this:

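# Written against the TensorFlow 1.x API (tf.sparse_to_dense, tf.InteractiveSession)
import tensorflow as tf

# Inputs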
labels = tf.constant([[1,1,0,0,0,0],[0,0,1,0,1,1]])
logits = tf.constant([[.6,.5,.5,.4,.2,0.1],[.05,.15,.2,.15,.05,.5]])
k = 3  # top-k size; the outputs listed at the end correspond to k = 3

# Predict the top-k for each row
topk, idx = tf.nn.top_k(logits, k)

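# Sort each row's indices in ascending order, since sparse_to_dense expects ordered indices.
# Note: idx now holds the negated indices, hence the -idx used below.
idx, _ = tf.nn.top_k(-idx, k)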

# Pair every column index with its row index to build full (row, col) coordinates
indices = tf.stack([tf.tile(tf.range(0, idx.get_shape()[0])[...,tf.newaxis], [1, k]), -idx], axis=2)
indices = tf.reshape(tf.squeeze(indices), [-1,2])

# Convert them to a dense multi-hot matrix
pred_labels = tf.sparse_to_dense(indices, logits.get_shape(), tf.ones(idx.get_shape()[0]*k))

# Check whether every positive label in each row is among the top-k predictions
sum1 = tf.reduce_sum(tf.cast(labels, tf.float32),1)
sum2 = tf.reduce_sum(tf.multiply(tf.cast(labels, tf.float32), tf.cast(pred_labels, tf.float32)), 1)
acc = tf.equal(sum1, sum2)
sess = tf.InteractiveSession()
print(sess.run(pred_labels))
print(sess.run(acc))

# Outputs
#[[ 1.  1.  1.  0.  0.  0.]
# [ 0.  1.  1.  0.  0.  1.]]

#[ True False]
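As a cross-check outside TensorFlow, the same steps (take the top-k indices, build a multi-hot mask, compare the overlap with the label count) can be sketched in NumPy. This is an illustrative re-implementation under the assumption that ties are broken in favor of the lower index, which is how `tf.nn.top_k` is documented to behave. It is not part of the original answer.

```python
import numpy as np

labels = np.array([[1, 1, 0, 0, 0, 0],
                   [0, 0, 1, 0, 1, 1]])
logits = np.array([[.6, .5, .5, .4, .2, .1],
                   [.05, .15, .2, .15, .05, .5]])
k = 3

# Indices of the k largest logits per row; a stable sort on the negated
# logits keeps the lower index first when values tie.
top_idx = np.argsort(-logits, axis=1, kind="stable")[:, :k]

# Multi-hot mask with ones at the top-k positions of each row.
pred_labels = np.zeros_like(logits)
np.put_along_axis(pred_labels, top_idx, 1.0, axis=1)

# A row is correct when every positive label falls inside the mask,
# i.e. the overlap count equals the number of positive labels.
acc = (labels * pred_labels).sum(axis=1) == labels.sum(axis=1)

print(pred_labels)
print(acc)
```

With k = 3 this yields the same mask and the same per-row result as the outputs listed above, which is a quick way to sanity-check the graph-based version on small inputs.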