如何在Windows上构建Tensorflow自定义操作

时间:2019-07-01 17:48:00

标签: python windows tensorflow

我一直在尝试使用Windows版本的TensorFlow进行自定义操作。我遵循了指南:

https://www.tensorflow.org/guide/extend/op#build_a_pip_package_for_your_custom_op

我能够正确构建,但是当我尝试对其进行测试时,似乎仍然无法识别我的自定义操作。

似乎可以识别内置的自定义操作tf.user_ops.my_fact, 但无法识别我创建的user_opssquared_out

这是我逐步执行的操作:

  1. 克隆存储库:
git clone https://github.com/tensorflow/tensorflow.git 
  1. 我将squared_out.cc的C ++实现放入tensorflow \ tensorflow \ core \ user_ops
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

REGISTER_OP("SquaredOut")
    .Input("to_square: int32")
    .Output("squared: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext *c) {
        c->set_output(0, c->input(0));
        return Status::OK();
    });


class SquaredOutOp : public OpKernel {

    public:
        explicit SquaredOutOp(OpKernelConstruction* context) : OpKernel(context) {}

        void Compute(OpKernelContext* context) override {

            const Tensor& input_tensor = context->input(0);
            auto input = input_tensor.flat<int32>();

            Tensor* output_tensor = NULL;
            OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor));

            auto output_flat = output_tensor->flat<int32>();

            const int N = input.size();

            for(int i = 0; i < N; ++i){

                output_flat(i) = input(i) * input(i);

            }
        }
};

REGISTER_KERNEL_BUILDER(Name("SquaredOut").Device(DEVICE_CPU), SquaredOutOp);
  1. 之后,我使用bazel配置并构建了tensorflow库:
python ./configure.py
bazel build --config=opt //tensorflow/tools/pip_package:build_pip_package
bazel-bin\tensorflow\tools\pip_package\build_pip_package C:/tmp/tensorflow_pkg
  1. 然后我安装了bazel生成的.whl:
pip install C:/tmp/tensorflow_pkg/tensorflow-version-cp35-cp35m-win_amd64.whl
  1. 最后,我尝试测试从squared_out新导入的user_ops
import tensorflow as tf
#This works fine with build in fact function
tf.user_ops.fact

此返回:

<function my_fact at 0x0000020CE5601048>

但是,我的自定义操作无效:

tf.user_ops.squared_out

AttributeError: module 'tensorflow._api.v1.user_ops' has no attribute 'squared_out'

0 个答案:

没有答案