tf.nn.in_top_k:目标超出范围

时间:2016-06-02 09:18:58

标签: tensorflow

我从张量流中调整cifar10网络,以解决我自己的分类问题。我已经训练了网络,现在我尝试使用cifar10_eval.py

评估训练模型
top_k_op = tf.nn.in_top_k(logits, labels, 1)

但我得到以下错误。经过进一步调查,目标指数在2,3和4之间变化

tensorflow.python.framework.errors.InvalidArgumentError: targets[3] is out of range

到现在为止,我知道我的标签 - Tensor出了问题。它是一个int32-Tensor,形状(50,)如下所示。

labels = {Tensor} Tensor("batch_processing/Reshape_1:0", shape=(50,), dtype=int32, device=/device:CPU:0)

我的数据集只有2个类/标签。也许这可能是问题所在。有谁知道,问题是什么?

2 个答案:

答案 0 :(得分:7)

总结一下,函数tf.nn.in_top_k(predictions, targets, k)(参见doc)有参数:

  • 预测:shape [batch_size, num_classes],输入float32
  • targets(正确的标签):shape [batch_size],输入int32或int64

InvalidArgumentError: targets[i] is out of range元素targets[i]超出范围时,该函数会引发错误predictions[i]

例如,有2个类(num_classes=2)和targets=[1, 3]。 使用这些目标时,您会看到错误InvalidArgumentError: targets[1] is out of range,因为targets[1] = 3超出predictions[1]范围仅为2的范围。

要检查您的labels是否正确,您可以打印最大值:

labels = ...
labels_max = tf.reduce_max(labels)

sess = tf.Session()
print sess.run(labels_max)

如果打印的值高于num_classes,则表示您遇到问题。

答案 1 :(得分:1)

因此,如果您以一种热编码的方式进行预测,则目标必须是放置一(1)种热的正确索引。 例如:

bb=tf.nn.in_top_k([[0,1],[1,0],[0,1]]  ,   [1,1,1],1)

将返回:

[正确错误正确]

因此,您必须将可能是一个热门的目标转换为该索引方法

脾气暴躁:

targetsindex = np.argmax(targets, axis=1)

张量:

targetsindex = tf.argmax(targets, axis=0)