有几条SO帖子显示了如何使用tf.custom_gradient
或tf.RegisterGradient
装饰器创建自定义渐变。
@ops.RegisterGradient("MyCustomGradient")
@tf.custom_gradient("MyCustomGradient")
但是您如何明确查找当前模型或图形使用了哪些梯度方法?查看ops._gradient_registry._registry
,当前有4个不同的Relu
键:
'Relu'
,'Relu6'
,'Relu6Grad'
,'ReluGrad'
。 如果我想使用自定义渐变色来代替'Relu'
,我一直在使用:
with g.gradient_override_map({'Relu': 'MyCustomGradient'}):
...
但是如何确定'Relu'
是否实际上是图形或模型当前正在使用的渐变方法?