如何在自定义Op中改变Tensorflow变量?

时间:2018-03-02 17:33:43

标签: python tensorflow

我正在尝试修改简单Adding a New Op,以便它不会创建一个新的Tensor作为返回值,但它实际上会改变输入Tensor并返回它。我知道这是可能的,因为scatter Op正在做同样的事情,但是查看scatter Op源代码,我无法弄清楚到底要做什么缺乏C++经验。

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;


REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });



#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {

    // Grab the input tensor
    Tensor input_tensor = context->mutable_input(0, true);
    auto input = input_tensor.flat<int32>();

    // We always return the input ref.
    context->forward_ref_input_to_ref_output(0, 0);

    // Set all but the first element of the output tensor to 0.
    const int N = input.size();
    for (int i = 1; i < N; i++) {
            input(i) = 0;
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

如果我编译上面的代码并运行一个简单的Python脚本来测试它,我会收到以下错误:

Python(14820,0x700003a7e000) malloc: *** error for object 0x7fd5c45a5a88: pointer being freed was not allocated
*** set a breakpoint in malloc_error_break to debug

我需要在代码中更改哪些内容才能满足我的需求?

1 个答案:

答案 0 :(得分:2)

我认为你最好修改抓取输入和输出的过程。实际上根据你的REGISTER_OP,它不是参考输入,所以

context->mutable_input(0, true)

将是

context->input(0)

此外,设置输出将更改为

context->set_output(0, context->input(0))

我认为在设置输出后它会起作用。