我尝试用in_top_k函数进行实验,以查看该函数到底在做什么。但是我发现了一些令人困惑的行为。
首先我编码如下
-Wall
然后它会产生以下错误:
import numpy as np
import tensorflow as tf
target = tf.constant(np.random.randint(2, size=30).reshape(30,-1), dtype=tf.int32, name="target")
pred = tf.constant(np.random.rand(30,1), dtype=tf.float32, name="pred")
result = tf.nn.in_top_k(pred, target, 1)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
targetVal = target.eval()
predVal = pred.eval()
resultVal = result.eval()
然后我将代码更改为
ValueError: Shape must be rank 1 but is rank 2 for 'in_top_k/InTopKV2' (op: 'InTopKV2') with input shapes: [30,1], [30,1], [].
但是现在错误变为
import numpy as np
import tensorflow as tf
target = tf.constant(np.random.randint(2, size=30), dtype=tf.int32, name="target")
pred = tf.constant(np.random.rand(30,1).reshape(-1), dtype=tf.float32, name="pred")
result = tf.nn.in_top_k(pred, target, 1)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
targetVal = target.eval()
predVal = pred.eval()
resultVal = result.eval()
那么输入应该是1级还是2级?
答案 0 :(得分:1)
对于in_top_k
,targets
需要排名1(类别索引),predictions
排名2(每个类别的分数)。可以轻松from the docs看到。
这意味着两条错误消息实际上每次都抱怨不同输入(第一次针对目标,第二次针对预测),但有趣的是根本没有在消息中提及…… ,以下代码段应更像它:
import numpy as np
import tensorflow as tf
target = tf.constant(np.random.randint(2, size=30), dtype=tf.int32, name="target")
pred = tf.constant(np.random.rand(30,1), dtype=tf.float32, name="pred")
result = tf.nn.in_top_k(pred, target, 1)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
targetVal = target.eval()
predVal = pred.eval()
resultVal = result.eval()
在这里,我们基本上结合了“两个片段中的最佳片段”:第一个片段的预测和第二个片段的目标。但是,按照我对文档的理解方式,即使对于二进制分类,我们也需要两个值来进行预测,每个类一个。所以像
import numpy as np
import tensorflow as tf
target = tf.constant(np.random.randint(2, size=30), dtype=tf.int32, name="target")
pred = tf.constant(np.random.rand(30,1), dtype=tf.float32, name="pred")
pred = tf.concat((1-pred, pred), axis=1)
result = tf.nn.in_top_k(pred, target, 1)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
targetVal = target.eval()
predVal = pred.eval()
resultVal = result.eval()