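I am trying to modify the simple example from "Adding a New Op" so that, instead of creating a new Tensor as its return value, it mutates the input Tensor and returns it. I know this is possible because the scatter op does the same thing, but after reading the scatter op source code I cannot figure out exactly what to do, because I lack C++ experience.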
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });
#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    Tensor input_tensor = context->mutable_input(0, true);
    auto input = input_tensor.flat<int32>();

    // We always return the input ref.
    context->forward_ref_input_to_ref_output(0, 0);

    // Set all but the first element of the output tensor to 0.
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      input(i) = 0;
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
If I compile the code above and run a simple Python script to test it, I get the following error:
Python(14820,0x700003a7e000) malloc: *** error for object 0x7fd5c45a5a88: pointer being freed was not allocated
*** set a breakpoint in malloc_error_break to debug
What do I need to change in the code to get the behavior I want?
Answer 0 (score: 2)
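I think you should change how you grab the input and how you set the output. According to your REGISTER_OP, the input is not a reference input, so
context->mutable_input(0, true)
should become
context->input(0)
Likewise, the call that sets the output should change to
context->set_output(0, context->input(0))
I think it will work once the output is set this way.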
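To make the distinction behind the question concrete, here is a minimal, TensorFlow-free sketch of "mutate the caller's buffer and return it" versus "work on a copy". The function names ZeroOutInPlace and ZeroOutCopy are hypothetical and std::vector only stands in for a tensor buffer; this is not the TensorFlow API. For what it is worth, the scatter ops get the in-place behavior by declaring their variable input as Ref(T) in REGISTER_OP, whereas the op in the question declares a plain int32 input.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the kernel body: zero every element except the
// first, writing directly into the caller's buffer (no new allocation).
std::vector<std::int32_t>& ZeroOutInPlace(std::vector<std::int32_t>& values) {
  for (std::size_t i = 1; i < values.size(); ++i) {
    values[i] = 0;
  }
  return values;  // the same object that was passed in
}

// Contrast: the argument is taken by value, so the loop runs on a copy and
// the caller's buffer is left untouched.
std::vector<std::int32_t> ZeroOutCopy(std::vector<std::int32_t> values) {
  for (std::size_t i = 1; i < values.size(); ++i) {
    values[i] = 0;
  }
  return values;
}
```

Calling ZeroOutCopy on a vector leaves the original elements as they were, while ZeroOutInPlace returns a reference to the very same vector with everything after the first element set to 0.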