Calling a custom op using C++

Time: 2019-09-24 02:37:53

Tags: c++ tensorflow

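I built a very simple custom op, zero_out, and tried to run it from C++. The problem is that I don't know how to call this op, and I can't find any documentation about it.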

Code for the custom op zero_out:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;

REGISTER_OP("ZeroOut")
    .Input("to_zero: float")
    .Output("zeroed: float")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

class ZeroOutOp : public OpKernel {
public:
    explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

    void Compute(OpKernelContext* context) override {
        // Grab the input tensor
        const Tensor& input_tensor = context->input(0);
        auto input = input_tensor.flat<float>();

        // Create an output tensor with the same shape as the input
        Tensor* output_tensor = nullptr;
        OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                         &output_tensor));
        auto output_flat = output_tensor->flat<float>();

        // Set all but the first output element to 0,
        // then copy the first input value through
        const int N = input.size();
        for (int i = 1; i < N; i++)
            output_flat(i) = 0;
        if (N > 0) output_flat(0) = input(0);
    }
};

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
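Stripped of the TensorFlow types, the loop in Compute keeps the first element and writes 0 everywhere else. The plain C++ sketch below restates that logic on a std::vector; ZeroOutReference is a hypothetical name, not part of TensorFlow, and is only meant as a reference for checking what the op should return once it is called.

```cpp
#include <vector>

// Hypothetical reference implementation (not part of TensorFlow) of the
// ZeroOut kernel's logic: the output has the same size as the input,
// element 0 is copied through, and every other element is set to 0.
std::vector<float> ZeroOutReference(const std::vector<float>& input) {
  std::vector<float> output(input.size(), 0.0f);  // start with all zeros
  if (!input.empty()) {
    output[0] = input[0];  // preserve the first input value
  }
  return output;
}
```

For example, ZeroOutReference({5.0f, 3.0f, 2.0f}) returns {5, 0, 0}.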

After building it I got zero_out.so, and then loaded the library from C++:

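    #include <iostream>
    #include "tensorflow/c/c_api.h"
    #include "tensorflow/core/framework/op_def.pb.h"
    using namespace std;

    int main() {
        // Load the custom-op library; this registers ZeroOut and its kernel
        TF_Status* status_load = TF_NewStatus();
        TF_Library* lib_handle = TF_LoadLibrary("libzero_out.so", status_load);
        TF_Code code = TF_GetCode(status_load);
        cout << "code: " << code << endl;  // output: 0 (TF_OK)

        // Inspect the ops registered by the library
        TF_Buffer op_list_buf = TF_GetOpList(lib_handle);
        tensorflow::OpList op_list;
        op_list.ParseFromArray(op_list_buf.data, op_list_buf.length);
        cout << "oplist size = " << op_list.op_size() << endl;    // output: 1
        cout << "oplist name = " << op_list.op(0).name() << endl;    // output: ZeroOut
    }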

I assume the op is now registered, but how do I actually call it? Could anyone give an example or a code snippet?

0 answers