我正在尝试在 `@tf.function` 装饰的函数中计算 `tf.case` 的梯度。
例如,假设我有一个 case 函数,它接收一批输入,并根据每个输入的正负号计算输出:
def case_fn(x):
    """Map positive entries of ``x`` to ``2 * x`` and the others to ``x - 2``.

    The work is dispatched with ``tf.case``:
    - a fast path when every entry is positive;
    - a fast path when every entry is <= 0;
    - a gather/scatter path when the batch mixes both.

    Args:
        x: Float tensor whose first axis is the batch axis. The visible
            caller passes shape ``(N, 1)``; shape ``(N,)`` also works.
            NOTE(review): rows with more than one feature are not supported
            by the per-row masks below -- confirm the intended input shapes.

    Returns:
        A tensor with the same shape as ``x``.
    """
    # Flatten the per-row masks with reshape rather than squeeze: squeeze
    # would also drop the batch axis when N == 1.
    positive_mask = tf.reshape(tf.math.greater(x, 0.), [-1])
    negative_mask = tf.reshape(tf.math.less_equal(x, 0.), [-1])

    # tf.where returns indices of shape (k, 1). Flatten with reshape, not
    # squeeze: squeeze turns a single match (k == 1) into a scalar, which
    # breaks the emptiness checks and makes tf.gather drop an axis.
    positive_idx = tf.cast(tf.reshape(tf.where(positive_mask), [-1]), tf.int32)
    negative_idx = tf.cast(tf.reshape(tf.where(negative_mask), [-1]), tf.int32)

    def all_positive_case():
        return x * 2.

    def all_negative_case():
        return x - 2.

    def some_positive_some_negative_case():
        y_positive = tf.gather(x, positive_idx) * 2.
        y_negative = tf.gather(x, negative_idx) - 2.
        # Scatter each part back to its original rows. The two index sets are
        # disjoint, so summing the zero-filled results recombines the batch.
        output_shape = tf.shape(x)
        y_positive = tf.scatter_nd(
            tf.expand_dims(positive_idx, 1), y_positive, output_shape)
        y_negative = tf.scatter_nd(
            tf.expand_dims(negative_idx, 1), y_negative, output_shape)
        return y_positive + y_negative

    all_positive = tf.math.equal(tf.size(negative_idx), 0)
    all_negative = tf.math.equal(tf.size(positive_idx), 0)

    # NOTE(review): the gradient of tf.gather is an IndexedSlices, while the
    # other branches produce dense gradients. Inside @tf.function, some
    # TF 2.x releases fail on this mismatch in cond_v2 with
    # "Insufficient elements in branch_graphs[0].outputs".
    # If upgrading TF is not an option, the branch-free
    # tf.where(x > 0., x * 2., x - 2.) gives the same values for finite
    # inputs, with dense gradients.
    return tf.case(
        [(all_positive, all_positive_case), (all_negative, all_negative_case)],
        default=some_positive_some_negative_case,
    )
然后,我使用以下代码计算梯度:
trainable_variable = tf.Variable([[1.], [-1.], [2.], [-2.]])


@tf.function
def compute_grad():
    """Differentiate ``case_fn(trainable_variable)`` w.r.t. the variable.

    Returns:
        Whatever ``tape.gradient`` yields for the variable.
    """
    with tf.GradientTape() as tape:
        output = case_fn(trainable_variable)
    return tape.gradient(output, trainable_variable)


print(compute_grad())
如果不使用 `@tf.function` 装饰器,它会返回正确的值:`IndexedSlices(indices=tf.Tensor([0, 2, 1, 3], shape=(4,), dtype=int32), values=tf.Tensor([[2.],[2.],[1.],[1.]], shape=(4, 1), dtype=float32), dense_shape=tf.Tensor([4 1], shape=(2,), dtype=int32))`。
但是,如果使用 `@tf.function` 装饰器,它会抛出 `ValueError`:
Traceback (most recent call last):
File "examples/case_gradient.py", line 102, in <module>
print(compute_grad())
File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 568, in __call__
result = self._call(*args, **kwds)
File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 615, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 497, in _initialize
*args, **kwds))
File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2389, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2703, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2593, in _create_graph_function
capture_by_value=self._capture_by_value),
File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py", line 978, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 439, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py", line 968, in wrapper
raise e.ag_error_metadata.to_exception(e)
ValueError: in converted code:
examples/case_gradient.py:99 compute_grad *
grad = tape.gradient(y, trainable_variable)
/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/backprop.py:1029 gradient
unconnected_gradients=unconnected_gradients)
/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/imperative_grad.py:77 imperative_grad
compat.as_str(unconnected_gradients.value))
/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/backprop.py:141 _gradient_function
return grad_fn(mock_op, *out_grads)
/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/cond_v2.py:121 _IfGrad
false_graph, grads, util.unique_grad_fn_name(false_graph.name))
/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/cond_v2.py:381 _create_grad_func
func_graph=_CondGradFuncGraph(name, func_graph))
/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py:978 func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/cond_v2.py:380 <lambda>
lambda: _grad_fn(func_graph, grads), [], {},
/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/cond_v2.py:371 _grad_fn
src_graph=func_graph)
/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/gradients_util.py:669 _GradientsHelper
lambda: grad_fn(op, *out_grads))
/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/gradients_util.py:336 _MaybeCompile
return grad_fn() # Exit early
/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/gradients_util.py:669 <lambda>
lambda: grad_fn(op, *out_grads))
/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/cond_v2.py:183 _IfGrad
building_gradient=True,
/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/cond_v2.py:219 _build_cond
_make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph])
/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/cond_v2.py:652 _make_indexed_slices_indices_types_match
(current_index, len(branch_graphs[0].outputs)))
ValueError: Insufficient elements in branch_graphs[0].outputs.
Expected: 6
Actual: 3
我在这里遗漏了什么?
答案 0 :(得分:1)
我在最新的 2.2.0-rc3 版本上检查过,没有复现此问题。该问题可能已在新版本中修复。