我正在将logit与循环中的标签进行比较:
for r in range(logits.shape[0]):
if labels[r] == np.argmax(logits[r]):
guessed += 1.0
其中labels
是一维整数标签数组,logits
是2D数组,第二维是标签的概率。
以上解决方案是效率不高的Python循环。应该有一个常用的numpy
或tensorflow
快捷方式来做到这一点。你能建议一个吗?
答案 0 :(得分:1)
您可以通过np.argmax(logits,axis=1)
一次获得所有的最大值。以下可以替换for循环以获取猜测的总数:
guessed = np.sum(labels == np.argmax(logits,axis=1))