在张量流中定义新的op并计算梯度

时间:2018-06-06 12:08:37

标签: python tensorflow

我是tensorflow的新人。

我尝试使用非常简单的公式在Tensorflow中添加新的操作。

但我不知道如何在代码中定义op'渐变(python或C)。

我已经研究了

中的样本

https://www.tensorflow.org/extend/adding_an_op

https://github.com/davidstutz/tensorflow-cpp-op-example/blob/master/_inner_product_grad.py

然后我无法理解。

示例:F(y,x)= y ^ 2 + x ^ 3

下面我已经在C代码中添加了op。对我来说工作正常。

    #include "tensorflow/core/framework/op.h"
    #include "tensorflow/core/framework/shape_inference.h"
    #include <math.h>
    using namespace tensorflow;

    REGISTER_OP("FunOut")
        .Input("to_funy: float")
        .Input("to_funx: float")
        .Output("funed: float")
        .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
          c->set_output(0, c->input(0));
          return Status::OK();
        });

    #include "tensorflow/core/framework/op_kernel.h"

    using namespace tensorflow;

    class FunOutOp : public OpKernel {
     public:
      explicit FunOutOp(OpKernelConstruction* context) : OpKernel(context) {}

      void Compute(OpKernelContext* context) override {
        const Tensor& input_tensorY = context->input(0);
        auto inputY = input_tensorY.flat<float>();
        auto input_shapeY = input_tensorY.shape();

        const Tensor& input_tensorX = context->input(1);
        auto inputX = input_tensorX.flat<float>();
        auto input_shapeX = input_tensorX.shape();

        const int All = input_shapeX.dim_size(0);
        // Create an output tensor
        Tensor* output_tensor = NULL;
        OP_REQUIRES_OK(context, context->allocate_output(0,input_tensorY.shape(), &output_tensor));
        auto output_flat = output_tensor->flat<float>();
        for(int i = 0; i<All;i++){
            output_flat(i) = inputY(i)*inputY(i)+inputX(i)*inputX(i)*inputX(i);
        }

      }
    };

    REGISTER_KERNEL_BUILDER(Name("FunOut").Device(DEVICE_CPU), FunOutOp)

应该是

        import tensorflow as tf
        from tensorflow.python.framework import ops
        ........
        @ops.RegisterGradient("Fun")
        def _fun_grad(op, grad):

        ......
        ......

        return [,]

有人可以帮助我吗?并解释。非常感谢。

据我所知,dF / dx = 3 * x,dF / dy = 2 * y,......

0 个答案:

没有答案