我正在使用各自的渐变在Tensorflow中创建几个自定义操作。一切都可以单独很好地工作,但是当我的两个包用相同的名称定义两个不同的操作(不同的输入)时,我面临一个问题。
为了简化我的问题,假设在两个软件包中定义了一个matmul
操作。可以很容易地在以下代码中使用它:
import tensorflow as tf
my_ops_a = tf.load_op_library('libpackage_a.so')
my_ops_b = tf.load_op_library('libpackage_b.so')
x, y = tf.random.uniform(10,10), tf.random.uniform(10,10)
my_ops_a.matmul(x, y)
my_ops_b.matmul(x, y)
其梯度可以通过以下方式通知Tensorflow:
from tensorflow.python.framework import ops as tf_ops
@tf_ops.RegisterGradient("Matmul")
def _mat_mul_grad(op, grad):
return my_ops_a.mat_mul_grad(grad, op.inputs[0], op.inputs[1])
@tf_ops.RegisterGradient("Matmul")
def _mat_mul_grad(op, grad):
return my_ops_b.mat_mul_grad(grad, op.inputs[0], op.inputs[1])
但是,@tf_ops.RegisterGradient
无法识别我所指的matmul
。
实际上,当我尝试运行通知的代码时,出现以下错误:
KeyError: "Registering two gradient with name 'Matmul'! (Previous registration was in <module> ...)
如何通知Tensorflow我所指的是特定软件包的操作?
谢谢。