无法使用tf.metrics.recall

时间:2019-06-03 07:14:21

标签: python-3.x tensorflow deep-learning

我对tensorflow非常陌生。 我只是想了解如何使用tf.metrics.recall

我正在做以下事情

true = tf.zeros([64, 1])
pred = tf.random_uniform([64,1], -1.0,1.0)
with tf.Session() as sess:
    t,p = sess.run([true,pred])
#     print(t)
#     print(p)
    rec, rec_op = tf.metrics.recall(labels=t, predictions=p)
    sess.run(rec_op,feed_dict={t: t,p: p})
    print(recall)

这给了我以下错误:

TypeError                                 Traceback (most recent call last)
<ipython-input-43-7245c92d724d> in <module>
     25 #     print(p)
     26     rec, rec_op = tf.metrics.recall(labels=t, predictions=p)
---> 27     sess.run(rec_op,feed_dict={t: t,p: p})
     28     print(recall)

TypeError: unhashable type: 'numpy.ndarray'

请帮助我更好地理解这一点。 预先谢谢你

1 个答案:

答案 0 :(得分:0)

代码中的

标签和预测返回张量输出,它们是numpy数组。如果愿意,可以使用numpy或自己的实现来计算对它们的召回率。使用指标的好处是您可以仅使用tensorflow一次运行所有内容。

with tf.Session() as sess:
    rec, rec_op = tf.metrics.recall(labels=true, predictions=pred)
    batch_recall, _ = sess.run([rec, rec_op],feed_dict={t: t,p: p})
    print(recall)

请注意,您使用张量构造tf.metrics.recall。