我试图了解TensorFlow 2.0中的CategoricalCrossentropy()
损失函数。当我使用
tf.keras.metrics.CategoricalCrossentropy(actual, pred)
我收到以下错误:
ValueError:具有多个元素的数组的真值是 暧昧。使用a.any()或a.all()
任何人都可以解释如何纠正它吗?
答案 0 :(得分:2)
从docs开始,计算交叉熵,使用
cce = tf.keras.metrics.CategoricalCrossentropy()
# cce.update_state(target, prediction)
cce.update_state([[0, 1, 0], [0, 0, 1]], [[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
cce.result().numpy()
# 1.1769392
OTOH,如果您要计算交叉熵损耗,请改用tf.keras.losses.CategoricalCrossentropy
:
cce = tf.keras.metrics.CategoricalCrossentropy()
cce([[0, 1, 0], [0, 0, 1]], [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]).numpy()
# 1.1769392