我正在尝试使用tf.nn.weighted_cross_entropy_with_logits API,但我发现当重量不是1.0(1.0表示没有重量)时,我无法得到正确的结果。
import tensorflow as tf
import numpy as np
def my_binary_crossentropy_np(labels, output, weight=10.0):
"""
Weighted binary crossentropy between an output tensor
and a target tensor.
"""
# transform back to logits
epsilon = 1e-08
np.clip(output, epsilon, 1.0 - epsilon, out=output)
output = np.log(output / (1.0 - output))
# https://www.tensorflow.org/api_docs/python/tf/nn/weighted_cross_entropy_with_logits
# l = 1 + (q - 1) * z
# (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
l = 1.0 + (weight - 1.0) * labels
loss1 = np.multiply(1.0 - labels, output)
loss2 = np.multiply(l, np.log(1.0 + np.exp(-abs(output))))
loss3 = np.maximum(-output, 0)
loss = loss1 + loss2 + loss3
return np.mean(loss)
def my_binary_crossentropy_tf(labels, output, weight=1.0):
"""
Weighted binary crossentropy between an output tensor
and a target tensor.
"""
epsilon = 1e-08
output = tf.clip_by_value(output, epsilon, 1.0 - epsilon)
output = tf.log(output / (1.0 - output))
# compute weighted loss
#loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=output)
loss = tf.nn.weighted_cross_entropy_with_logits(targets=labels, logits=output, pos_weight=weight)
return tf.reduce_mean(loss)
# generate random test data and random label
predict = np.random.rand(10, 8)
label = np.random.rand(10, 8)
label[label >= 0.5] = 1
label[label < 0.5] = 0
loss1 = my_binary_crossentropy_np(label, predict, 1.0)
print('loss1 = ', loss1)
loss1 = my_binary_crossentropy_np(label, predict, 10.0)
print('loss1 = ', loss1)
predict_tf = tf.convert_to_tensor(predict)
loss2 = my_binary_crossentropy_tf(label, predict_tf, 1.0)
loss2 = tf.Session().run(loss2)
print('loss2 = ', loss2)
loss2 = my_binary_crossentropy_tf(label, predict_tf, 10.0)
loss2 = tf.Session().run(loss2)
print('loss2 = ', loss2)
运行结果:
loss1 = 1.02193164517
loss1 = 1.96332399324
loss2 = 1.02193164517
loss2 = 4.80529539791
答案 0 :(得分:1)
my_binary_crossentropy_np的实现是错误的。 这是正确的:
l = (weight - 1.0) * labels + 1.0
loss1 = np.multiply(1.0 - labels, output)
loss2 = np.multiply(l, np.log(1.0 + np.exp(-abs(output))) + np.maximum(-output, 0))
loss = loss1 + loss2