tf.divide()不支持使用int类型输入的渐变

时间:2017-12-13 06:36:28

标签: tensorflow

我在Ubuntu14.04上安装了tensorflow-1.4.1,我想用以下方式创建名为Dice Loss的自定义丢失函数:

>>> # true_labels and pred_labels are in type int64
>>> a = tf.square(true_labels)
>>> b = tf.square(pred_labels)
>>> upper = tf.multiply(true_labels, pred_labels)
>>> lower = tf.add(a, b, upper)
>>> iou = tf.divide(upper, lower)
>>> loss = tf.subtract(tf.constant(1, dtype=tf.float64), iou)

但是,当输入类型为int32或int64时,我发现tf.divide()不提供渐变。我在下面的更一般情况下检查了这个:

>>> import tensorflow as tf
>>> a = tf.constant(1, dtype=tf.int64)
>>> b = tf.constant(2, dtype=tf.int64)
>>> c = tf.divide(a, b)
>>> g = tf.gradients(c, a)
>>> print(g)

结果是:

>>> [None]

当我将dtypea更改为int64时,它可以返回正确的结果float32

任何人都可以帮我这个吗?非常感谢!

1 个答案:

答案 0 :(得分:0)

您可以使用upperlowertf.cast投射到tf.float32类型。分界线可以是:

upper = tf.cast(upper, tf.float32)
lower = tf.cast(lower, tf.float32)
iou = tf.divide(upper, lower)