我正在使用 TensorFlow 的 eager 模式(即时执行模式),并使用 tf.GradientTape 来记录梯度。但在对 tf.math.segment_max 操作求梯度时会引发错误。
import tensorflow as tf
import numpy as np
# Minimal reproduction: take the gradient of a loss built on
# tf.math.segment_max while running in eager mode.
tf.enable_eager_execution()

# Five random values wrapped in a float32 variable to differentiate against.
inputs = np.random.randn(5)
inputs_tensor = tf.Variable(inputs, dtype=tf.float32)
# Segment ids: the first three elements map to segment 0, the last two to
# segment 1.
seg_ids = tf.convert_to_tensor([0, 0, 0, 1, 1])

# The forward pass has to run inside the tape context so that it is recorded.
with tf.GradientTape() as tape:
    y = tf.math.segment_max(inputs_tensor, seg_ids)
    # Squared difference between each segment maximum and a target of 1.0,
    # averaged over the segments.
    loss = (y - tf.convert_to_tensor([1., 1.], tf.float32)) ** 2
    loss = tf.reduce_mean(loss, 0)

# NOTE(review): with the TensorFlow build shown in the traceback below, this
# call raises "TypeError: 'NoneType' object is not subscriptable" inside
# _SegmentMinOrMaxGrad because op.outputs is None. Possible workarounds to
# confirm: tf.math.unsorted_segment_max(inputs_tensor, seg_ids, num_segments=2)
# or a newer TensorFlow release.
grads = tape.gradient(loss, inputs_tensor)
错误显示如下。
TypeError Traceback (most recent call last)
<ipython-input-2-b33dd17ed2d8> in <module>
6 loss = tf.reduce_mean(loss, 0)
7 print(y)
----> 8 grads = tape.gradient(loss, inputs_tensor)
/anaconda3/envs/my-rdkit-env/lib/python3.6/site-packages/tensorflow/python/eager/backprop.py in gradient(self, target, sources, output_gradients, unconnected_gradients)
944 flat_sources,
945 output_gradients=output_gradients,
--> 946 unconnected_gradients=unconnected_gradients)
947
948 if not self._persistent:
/anaconda3/envs/my-rdkit-env/lib/python3.6/site-packages/tensorflow/python/eager/imperative_grad.py in imperative_grad(tape, target, sources, output_gradients, unconnected_gradients)
70 sources,
71 output_gradients,
---> 72 compat.as_str(unconnected_gradients.value))
/anaconda3/envs/my-rdkit-env/lib/python3.6/site-packages/tensorflow/python/eager/backprop.py in _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs, out_grads)
129 return [None] * num_inputs
130
--> 131 return grad_fn(mock_op, *out_grads)
132
133
/anaconda3/envs/my-rdkit-env/lib/python3.6/site-packages/tensorflow/python/ops/math_grad.py in _SegmentMaxGrad(op, grad)
277 def _SegmentMaxGrad(op, grad):
278 """Gradient for SegmentMax."""
--> 279 return _SegmentMinOrMaxGrad(op, grad)
280
281
/anaconda3/envs/my-rdkit-env/lib/python3.6/site-packages/tensorflow/python/ops/math_grad.py in _SegmentMinOrMaxGrad(op, grad)
257 print('op:', op)
258 print('op.outputs:', op.outputs)
--> 259 gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1])
260 is_selected = math_ops.equal(op.inputs[0], gathered_outputs)
261 num_selected = math_ops.segment_sum(math_ops.cast(is_selected, grad.dtype),
TypeError: 'NoneType' object is not subscriptable
我认为“SegmentMax”操作可能未正确注册,因为其 `op.outputs` 为 `None`。有什么解决办法吗?