如何在TensorFlow中使用带有shape =(1,1)的标签张量的tf.equal()?

时间:2018-01-15 18:39:45

标签: python python-3.x tensorflow

我试图评估单个图像。它到目前为止工作。我得到每个类的正确概率和正确的标签。但是当我尝试用tf.argmax(label, 1)来上课时,我总是得到课程" 0"。

...
image, label = ...
# label: Tensor("..", shape=(1, 1), dtype=int32)
logits = model(image)
# logits: Tensor("..", shape=(1, 10), dtype=float32)
predic = tf.nn.softmax(logits)
arg_log = tf.argmax(logits, 1)
arg_lbl = tf.argmax(label, 1)
...
pre, lbl, a_log, a_lbl = sess.run([predic, label, arg_log, arg_lbl])
print(pre)
# [[2.0451562e-06 # class 0
# 6.1964911e-06   # class 1
# 4.1852250e-06   # class 2
# 9.9847549e-01   # class 3 - We have a winner :)
# 8.2492170e-07   # class 4
# 3.1969071e-06   # class 5
# 1.5037126e-03   # class 6
# 1.6847488e-07   # class 7
# 6.7177882e-07   # class 8
# 3.4959594e-06]] # class 9
print(lbl)
# [[3]]
print(a_log)
# [3]
print(a_lbl)
# [0] # Why i dont get "3"?
...

我总是得到" 0"对于每个数据点。我想继续使用tf.equal(),但标签的argmax值错误,当然这是不可能的。有什么想法吗?:

...
image, label = ...
logits = model(image)
arg_log = tf.argmax(logits, 1)
arg_lbl = tf.argmax(label, 1) # What must i change here?
cor_pre = tf.equal(arg_log, arg_lbl)
...

修改

我每次都会得到#34; 0"因为我得到索引!我改变了问题: 如何在TensorFlow中使用形状=(1,1)的Tensor上使用tf.argmax()? 至: 如何在TensorFlow中使用带有shape =(1,1)的标签张量的tf.equal()?

3 个答案:

答案 0 :(得分:1)

根据documentationtf.argmax()需要inputaxis等参数。

如果你的标签有形状[1,1],你期望从轴1的argmax得到什么?只有一个条目。

最有可能的是,您希望将标签与argmaxed结果进行比较。所以:

...
image, label = ...
# label: Tensor("..", shape=(1, 1), dtype=int32)
logits = model(image)
# logits: Tensor("..", shape=(1, 10), dtype=float32)
predic = tf.nn.softmax(logits)
arg_log = tf.argmax(logits, 1)
...
pre, lbl, a_log, a_lbl = sess.run([predic, label, arg_log, arg_lbl])
cor_pre = tf.equal(arg_log, tf.cast(label, tf.int64))

答案 1 :(得分:0)

任何形状(1,1)的数组都只包含一个元素。那一个元素必须是数组中的最大元素。

答案 2 :(得分:0)

我找到了一个解决方案,将标签投放到int64,以便在tf.equal()中使用它。

...
image, label = ...
logits = model(image)
cor_pre = tf.equal(tf.argmax(logits, 1), tf.cast(label, tf.int64))
...