TensorFlow weighted_cross_entropy_with_logits produces wrong results

Date: 2018-01-02 17:17:42

Tags: python tensorflow

I am trying to use the tf.nn.weighted_cross_entropy_with_logits API, but I find that when the weight is not 1.0 (1.0 means no weighting), I cannot get the correct result.

import tensorflow as tf
import numpy as np

def my_binary_crossentropy_np(labels, output, weight=10.0):
  """
  Weighted binary crossentropy between an output tensor 
  and a target tensor.
  """
  # transform back to logits
  epsilon = 1e-08
  np.clip(output, epsilon, 1.0 - epsilon, out=output)
  output = np.log(output / (1.0 - output))

  # https://www.tensorflow.org/api_docs/python/tf/nn/weighted_cross_entropy_with_logits 
  # l = 1 + (q - 1) * z
  # (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
  l = 1.0 + (weight - 1.0) * labels
  loss1 = np.multiply(1.0 - labels, output)
  loss2 = np.multiply(l, np.log(1.0 + np.exp(-abs(output))))
  loss3 = np.maximum(-output, 0)
  loss = loss1 + loss2 + loss3

  return np.mean(loss)


def my_binary_crossentropy_tf(labels, output, weight=1.0):
  """
  Weighted binary crossentropy between an output tensor 
  and a target tensor.
  """
  epsilon = 1e-08
  output = tf.clip_by_value(output, epsilon, 1.0 - epsilon)
  output = tf.log(output / (1.0 - output))

  # compute weighted loss
  #loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=output)
  loss = tf.nn.weighted_cross_entropy_with_logits(targets=labels, logits=output, pos_weight=weight)
  return tf.reduce_mean(loss)


# generate random test data and random label
predict = np.random.rand(10, 8)

label = np.random.rand(10, 8)
label[label >= 0.5] = 1
label[label < 0.5] = 0


loss1 = my_binary_crossentropy_np(label, predict, 1.0)
print('loss1 = ', loss1)

loss1 = my_binary_crossentropy_np(label, predict, 10.0)
print('loss1 = ', loss1)


predict_tf = tf.convert_to_tensor(predict)
loss2 = my_binary_crossentropy_tf(label, predict_tf, 1.0)
loss2 = tf.Session().run(loss2)
print('loss2 = ', loss2)

loss2 = my_binary_crossentropy_tf(label, predict_tf, 10.0)
loss2 = tf.Session().run(loss2)
print('loss2 = ', loss2)

Output:

loss1 = 1.02193164517
loss1 = 1.96332399324

loss2 = 1.02193164517
loss2 = 4.80529539791

1 Answer:

Answer 0 (score: 1)

The implementation of my_binary_crossentropy_np is wrong: the max(-x, 0) term must also be multiplied by l, exactly as in the formula quoted from the TensorFlow documentation. Here is the corrected version:

  # the pos_weight factor l scales the whole numerically stable term
  # log(1 + exp(-|x|)) + max(-x, 0), not only the log part
  l = (weight - 1.0) * labels + 1.0
  loss1 = np.multiply(1.0 - labels, output)
  loss2 = np.multiply(l, np.log(1.0 + np.exp(-abs(output))) + np.maximum(-output, 0))
  loss = loss1 + loss2
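
For reference, here is a minimal sketch of the whole corrected NumPy function with that fix applied (the name my_binary_crossentropy_np_fixed is only illustrative). Run against the same label and predict arrays from the question, it should reproduce the TensorFlow value for pos_weight = 10.0 (loss2 in the output above):

import numpy as np

def my_binary_crossentropy_np_fixed(labels, output, weight=10.0):
  """
  Weighted binary crossentropy where the pos_weight factor l multiplies
  the whole stable term log(1 + exp(-|x|)) + max(-x, 0).
  """
  epsilon = 1e-08
  output = np.clip(output, epsilon, 1.0 - epsilon)   # avoid log(0); no in-place edit
  output = np.log(output / (1.0 - output))           # probabilities back to logits

  l = (weight - 1.0) * labels + 1.0
  loss = np.multiply(1.0 - labels, output) \
       + np.multiply(l, np.log(1.0 + np.exp(-np.abs(output))) + np.maximum(-output, 0))
  return np.mean(loss)

loss1_fixed = my_binary_crossentropy_np_fixed(label, predict, 10.0)
print('loss1_fixed = ', loss1_fixed)   # expected to match loss2 for pos_weight = 10.0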