如何向tensorflow添加操作(教程)

时间:2017-02-22 00:45:04

标签: python tensorflow

我想将自己的操作添加到tensorflow。 所以我读了https://www.tensorflow.org/extend/adding_an_op#use_the_op_in_python。 然后我尝试添加新操作,但它不起作用。 错误按摩:tensorflow.python.framework.errors_impl.NotFoundError:dlopen(zero_out.so,6):找不到图像

我把这个源代码放到tensorflow / tensorflow / core / user_ops命名为zero_out.cc。

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();
    // Create an output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output = output_tensor->flat<int32>();
"zero_out.cc" 35L, 1286C

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();
    // Create an output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output = output_tensor->flat<int32>();
    // Set all but the first element of the output tensor to 0.
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output(i) = 0;
    }
    // Preserve the first input value if possible.
    if (N > 0) output(0) = input(0);
  }
};

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

我只是用教程写同样的东西。 然后,我用相同的路径制作了BUILD文件。(tensorflow / tensorflow / core / user_ops)

load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")

tf_custom_op_library(
    name = "zero_out.so",
    srcs = ["zero_out.cc"],
)

然后,在终端(路径〜/ tensorflow)中,我用bazel构建命令。

bazel build --config opt //tensorflow/core/user_ops:zero_out.so

我试图运行我的代码。

import tensorflow as tf
zero_out_module = tf.load_op_library('zero_out.so')
with tf.Session(''):
  zero_out_module.zero_out([[1, 2], [3, 4]]).eval()

但只有按摩错误。

error massage: tensorflow.python.framework.errors_impl.NotFoundError: dlopen(zero_out.so, 6): image not found

我想知道什么是错的。

感谢阅读。

0 个答案:

没有答案