如何在张量流中为此操作定义渐变函数?

时间:2017-08-24 10:28:12

标签: tensorflow

我正在尝试在tensorflow中实现一个新的op,有三个输入张量和两个输出张量,如下所示(由于与此问题无关而忽略了一些代码):

REGISTER_OP("MyNewFuncOp")
    .Attr("alpha: float = 1.0")
    .Attr("beta: float = 1.0")
    .Attr("debug: bool = false")
    .Input("input1: float32")
    .Input("input2: float32")
    .Input("input3: float32")
    .Output("output1: float32")
    .Output("output2: float32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      c->set_output(1, c->input(1));
      return Status::OK();
    });

class MyNewFuncOp : public OpKernel {
 public:
  explicit MyNewFuncOp(OpKernelConstruction* context) : OpKernel(context) {
   // some staffs
   ...
  }
 void Compute(OpKernelContext* context) override {
   // some staffs
   ...
   Tensor* output_tensor1 = NULL;
   OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({height, width, channels}),
                                                 &output_tensor1));
   Tensor* output_tensor2 = NULL;
   OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({height, width, channels}),
                                                 &output_tensor2));

 // some other staffs
 ...
 }

梯度注册如下:

import tensorflow as tf
from tensorflow.python.framework import ops
custom_module = tf.load_op_library('MyNewFunc.so')

@ops.RegisterGradient("MyNewFun")
def _MyNewFun_grad(op, grad1, grad2):
    input3 = op.inputs[2]
    return [grad1, grad2, tf.zeros_like(input3)]

但是这个渐变函数在我的实验中似乎是错误的,它可以运行正常,但是在构建训练运算符中运行grads = opt.compute_gradients(total_loss)后,此操作将产生错误的结果。但是这个操作可以运行正常并且还在评估状态中产生正确的结果(没有训练,即没有梯度计算)。所以我意识到这个梯度函数可能是错误的。我在官方文件https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/extend/adding_an_op.md#implement-gradient中阅读了此页面。事实上,在这个操作中,我只想要将两个输出张量(output1output2)的两个输出张量(input1input2)直接复制(反向传播)到前两个输入张量(即{{{} 1}}和jojo *alpha = malloc(sizeof(int)*100); )。

如何为此操作实现正确的渐变功能?感谢。

1 个答案:

答案 0 :(得分:0)

终于得到了答案:

import tensorflow as tf
from tensorflow.python.framework import ops
custom_module = tf.load_op_library('MyNewFunc.so')

@ops.RegisterGradient("MyNewFun")
def _MyNewFun_grad(_, *grads):
    return list(grads) + [None]