我构建了一个非常简单的自定义操作 zero_out（op 名称为 ZeroOut），并尝试在 C++ 中运行它。问题是：该如何调用这个操作？我没有找到关于这方面的文档。
自定义操作 zero_out 的代码：
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
REGISTER_OP("ZeroOut")
.Input("to_zero: float")
.Output("zeroed: float")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
});
// CPU kernel for the "ZeroOut" op: the output has the input's shape, its first
// element is a copy of the input's first element, and every other element is 0.
class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input_tensor = context->input(0);
    // View the input as 1-D regardless of its rank.
    auto input = input_tensor.flat<float>();

    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<float>();

    // Zero the whole output in one Eigen expression rather than an explicit
    // loop with an `int` counter: `size()` returns an Eigen index type
    // (ptrdiff_t-sized), and narrowing it to `int` would truncate the loop
    // bound for tensors with more than INT_MAX elements.
    output_flat.setZero();

    // Preserve the first input value if there is one; empty tensors stay empty.
    if (input.size() > 0) output_flat(0) = input(0);
  }
};

// Registers ZeroOutOp as the implementation of "ZeroOut" on CPU devices only.
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
构建后，我得到了 zero_out.so，然后在 C++ 中加载了这个库：
TF_Status* status_load = TF_NewStatus();
TF_Library* lib_handle = TF_LoadLibrary("libzero_out.so", status_load);
TF_Code code = TF_GetCode(status_load);
cout << "code: " << code << endl; // output: 0
TF_Buffer op_list_buf = TF_GetOpList(lib_handle);
tensorflow::OpList op_list;
op_list.ParseFromArray(op_list_buf.data, op_list_buf.length);
cout << "oplist size = " << op_list.op_size() << endl; // output: 1
cout << "oplist name = " << op_list.op(0).name() << endl; // output: ZeroOut
我认为这个 op 应该已经注册成功了。那么接下来该如何调用它？能否给出一些示例或代码片段？