TensorFlow 2.0使用GradientTape在dtype = int32上返回意外输出

时间:2019-03-22 11:08:05

标签: tensorflow

以下代码应输出x = 2的y = x * x的梯度,即值为4.但是,使用TensorFlow 2.0.0-alpha0时,该代码将打印出None值。如下一个代码片段所示,当x的定义更改为使用tf.float32而不是tf.int32时,输出将更改为正确的值4。是否有任何文档将数据类型的要求阐明为:是GradientTape在这种情况下可以正常工作的浮点数吗?

print(tf.__version__)

x = tf.constant(2, dtype=tf.int32)

with tf.GradientTape() as tape:
  tape.watch(x)
  y = x ** 2
  print(tape.gradient(y, x))

输出:

2.0.0-alpha0
None

在下一个代码段中注意对tf.float32的更改:

print(tf.__version__)

x = tf.constant(2, dtype=tf.float32)

with tf.GradientTape() as tape:
  tape.watch(x)
  y = x ** 2
  print(tape.gradient(y, x))

输出:

2.0.0-alpha0
tf.Tensor(4.0, shape=(), dtype=float32)

1 个答案:

答案 0 :(得分:1)

原因是tf.gradient不会通过整数张量传播梯度。这个github问题中已经引用了这个:

https://github.com/tensorflow/tensorflow/issues/20524