我的设置是给定一组数据标签,其中一些数据标签可以有多个(每个例子的最大数量为1,用n表示):
1,0,0,0,1,0
0,0,1,0,1,1
....
1,1,1,0,0,0
当我预测我想看看前n个logits的索引是什么时,如果预测索引包含所有一个的标签索引,那么它是正确的预测。如何在tensorflow中实现呢?
答案 0 :(得分:0)
您可以这样做:
#inputs
labels = tf.constant([[1,1,0,0,0,0],[0,0,1,0,1,1]])
logits = tf.constant([[.6,.5,.5,.4,.2,0.1],[.05,.15,.2,.15,.05,.5]])
k = 2
# Predict the top-k for each row
topk, idx = tf.nn.top_k(logits, k)
# Sort the index for input to sparse_to_dense matrix
idx, _ = tf.nn.top_k(-idx, k)
# Obtain the full indices
indices = tf.stack([tf.tile(tf.range(0, idx.get_shape()[0])[...,tf.newaxis], [1, k]), -idx], axis=2)
indices = tf.reshape(tf.squeeze(indices), [-1,2])
#convert them to dense matrix
pred_labels = tf.sparse_to_dense(indices, logits.get_shape(), tf.ones(idx.get_shape()[0]*k))
#Calculate whether each row of labels contained in logits
sum1 = tf.reduce_sum(tf.cast(labels, tf.float32),1)
sum2 = tf.reduce_sum(tf.multiply(tf.cast(labels, tf.float32), tf.cast(pred_labels, tf.float32)), 1)
acc = tf.equal(sum1, sum2)
sess = tf.InteractiveSession()
print(sess.run(pred_labels))
sess.run(acc)
# Outputs
#[[ 1. 1. 1. 0. 0. 0.]
#[ 0. 1. 1. 0. 0. 1.]]
#[ True, False]