生成预测列表

时间:2018-11-06 21:41:38

标签: python tensorflow neural-network recurrent-neural-network

使用Tensorflow,有没有一种方法可以输出网络预测?

我的输出一直使用12个类的一个热表示

[1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
[0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
[0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0]
etc...

当尝试从模型中获取给定输入的预测时,我运行了以下代码

    prediction=tf.argmax(y,1)
    best = sess.run([prediction],feed_dict={x: batch_x, y: batch_y,
                                  seqlen: batch_seqlen})
    print("Prediction: ")
    print(best)

运行此代码并打印预测时,我的输出是:

[array([1, 5, 7, 7, 7, 4, 7, 9, 4, 4, 9, 6, 7, 8, 3, 2], dtype=int64)]
我输入的批量大小为16,所以有16个输出确实有意义。但是,这些都不是One Hot表示形式(不确定tensorflow的输出是否打算插入索引中,因此1实际上是onehot的某种形式

对于每个特定的X,是否都有一种方法可以创建排名的预测列表,那么在给定X的情况下,模型最有可能找到什么?

这有意义吗?

1 个答案:

答案 0 :(得分:1)

您正在采用1热门向量的tf.argmax,因此这就是为什么看到索引而不是您期望的1热门向量的原因。

要获取分类预测的排序列表,您可以采用预测向量并应用values, indices = tf.nn.top_k(prediction)values将是您的预测以降序排列,indices将是那些{{ 1}}'索引。