我正在做实验,通过实现tf.cast
和tf.py_func
为tf.RegisterGradient
函数提供渐变,如下所示:
# Define custom py_func which takes also a grad op as argument:
def py_func(func, inp, Tout, stateful=True, name=None, grad=None):
# Need to generate a unique name to avoid duplicates:
rnd_name = 'PyFuncGrad' + str(np.random.randint(0, 1E+8))
tf.RegisterGradient(rnd_name)(grad)
g = tf.get_default_graph()
with g.gradient_override_map({"PyFunc": rnd_name}):
return tf.py_func(func, inp, Tout, stateful=stateful, name=name)
def castInt32(x, name=None):
with ops.name_scope(name, "CastInt32", [x]) as name:
sqr_x = py_func(np.int32,
[x],
[tf.int32],
name=name,
grad=_castInt32) # <-- here's the call to the gradient
return sqr_x[0]
# Actual gradient:
def _castInt32(op, grad):
t = [dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64]
src_type = op.inputs[0].dtype.base_dtype
dst_type = grad.dtype.base_dtype
if src_type in t and dst_type in t:
return math_ops.cast(grad, src_type)
else:
return None
with tf.Session() as session:
#session.run(init)
x = tf.constant([20., 1.])
y = castInt32(x)
z = tf.gradients(y,x)[0]
print z
它输出[1. , 1.]
作为输出。但是,当我使用这些功能来训练简单的网络时,执行optimizer.compute_gradient
时会出现错误。
错误显示如下:
ValueError: Invalid type tf.int32 for Cast:0, expected: [tf.float32, tf.float64, tf.float16, tf.bfloat16].
如果是因为渐变的数据类型,则该数据类型已经在tf.float32
中。
有人可以帮助我了解这种情况吗?预先感谢。