如何使用TensorFlow 2.0计算CategoricalCrossentropy?

时间:2019-06-11 21:57:20

标签: python numpy tensorflow tensorflow2.0

我试图了解TensorFlow 2.0中的CategoricalCrossentropy()损失函数。当我使用

tf.keras.metrics.CategoricalCrossentropy(actual, pred)

我收到以下错误:

  

ValueError:具有多个元素的数组的真值是   暧昧。使用a.any()或a.all()

任何人都可以解释如何纠正它吗?

1 个答案:

答案 0 :(得分:2)

docs开始,计算交叉熵,使用

cce = tf.keras.metrics.CategoricalCrossentropy()
# cce.update_state(target, prediction)
cce.update_state([[0, 1, 0], [0, 0, 1]], [[0.05, 0.95, 0], [0.1, 0.8, 0.1]])

cce.result().numpy()
# 1.1769392

OTOH,如果您要计算交叉熵损耗,请改用tf.keras.losses.CategoricalCrossentropy

cce = tf.keras.metrics.CategoricalCrossentropy()
cce([[0, 1, 0], [0, 0, 1]], [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]).numpy()
# 1.1769392