Problem passing a tensor attr to a custom op in eager execution mode

Date: 2019-06-29 18:43:50

Tags: python tensorflow

I am defining a new custom op in C++ that takes one attribute of type tensor and one input tensor. An abridged version of the op code is:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

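// "attr" is a tensor-valued attribute (default: an empty float tensor);
// "in" is an ordinary float input.
REGISTER_OP("DoStuff")
    .Attr("attr: tensor = { dtype: DT_FLOAT }")
    .Input("in: float")
    .Output("out: float");

class DoStuffOp : public OpKernel {
public:
    explicit DoStuffOp(OpKernelConstruction *context) : OpKernel(context) {
        // Attrs are read once, when the kernel is constructed.
        OP_REQUIRES_OK(context, context->GetAttr("attr", &attr_));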
        // ...
    }

    void Compute(OpKernelContext *context) override {
        // ...
    }

private:
    Tensor attr_;
};

REGISTER_KERNEL_BUILDER(Name("DoStuff").Device(DEVICE_CPU), DoStuffOp);

I can compile the op into a .so file, and the following code then runs without problems:

import numpy as np
import tensorflow as tf

dostufflib = tf.load_op_library('build/do_stuff.so')
sess = tf.InteractiveSession()

sample_in = np.random.rand(3, 3)
sample_in_t = tf.convert_to_tensor(sample_in, dtype=np.float32)
sample_attr = np.zeros([3, 3], dtype=np.float32)
sample_attr_t = tf.contrib.util.make_tensor_proto(sample_attr)

# `in` is a Python keyword, hence the trailing underscore in `in_`.
Y = dostufflib.do_stuff(in_=sample_in_t, attr=sample_attr_t)
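A keyword argument literally named `in` cannot appear in a Python call, because `in` is a reserved word, so a generated op wrapper cannot expose that input name unchanged. As far as we know, TensorFlow's wrapper generator appends an underscore to such names (`in_`); `help(dostufflib.do_stuff)` should show the actual parameter name for the installed version. The standard-library sketch below shows the parse failure and that renaming convention; `python_arg_name` and `is_valid_call` are hypothetical helpers, not TensorFlow APIs.

```python
import keyword


def python_arg_name(op_arg_name: str) -> str:
    """Hypothetical helper: append "_" to names that are Python keywords,
    the convention we assume TensorFlow's generated wrappers follow."""
    return op_arg_name + "_" if keyword.iskeyword(op_arg_name) else op_arg_name


def is_valid_call(source: str) -> bool:
    """Return True if `source` parses as a Python expression."""
    try:
        compile(source, "<call>", "eval")
        return True
    except SyntaxError:
        return False


# `in=` cannot even be parsed, regardless of what the op library contains.
print(is_valid_call("do_stuff(in=x, attr=a)"))   # False
print(is_valid_call("do_stuff(in_=x, attr=a)"))  # True
print(python_arg_name("in"), python_arg_name("attr"))  # in_ attr
```

Only parsing is checked here, so the names `do_stuff`, `x` and `a` do not need to exist.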

However, if I try to use eager execution mode, i.e.

import numpy as np
import tensorflow as tf

tf.compat.v1.enable_eager_execution()
dostufflib = tf.load_op_library('build/do_stuff.so')

sample_in = np.random.rand(3, 3)
sample_in_t = tf.convert_to_tensor(sample_in, dtype=np.float32)
sample_attr = np.zeros([3, 3], dtype=np.float32)
sample_attr_t = tf.contrib.util.make_tensor_proto(sample_attr)

Y = dostufflib.do_stuff(in_=sample_in_t, attr=sample_attr_t)

the following error is raised:

tensorflow.python.framework.errors_impl.UnimplementedError: Attr sample_locs has unhandled type 6
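The number at the end of the message appears to be a value of the `TF_AttrType` enum from the TensorFlow C API, which the eager runtime consults when it sets op attributes. Assuming the enum values listed below (transcribed from `tensorflow/c/tf_attrtype.h`; check them against the headers of the installed version), type 6 is `TF_ATTR_TENSOR`, which would mean the eager code path that sets attributes does not handle `tensor`-typed attrs, while graph mode does. The standard-library sketch below decodes such a message; `decode_unhandled_attr` is a hypothetical helper, not a TensorFlow API.

```python
import re

# Values of the TF_AttrType enum in the TensorFlow C API, transcribed here as an
# assumption; verify against tensorflow/c/tf_attrtype.h for your TF version.
TF_ATTR_TYPES = {
    0: "TF_ATTR_STRING",
    1: "TF_ATTR_INT",
    2: "TF_ATTR_FLOAT",
    3: "TF_ATTR_BOOL",
    4: "TF_ATTR_TYPE",
    5: "TF_ATTR_SHAPE",
    6: "TF_ATTR_TENSOR",
    7: "TF_ATTR_PLACEHOLDER",
    8: "TF_ATTR_FUNC",
}

# Matches e.g. "Attr sample_locs has unhandled type 6".
_PATTERN = re.compile(r"Attr (\S+) has unhandled type (\d+)")


def decode_unhandled_attr(message: str):
    """Return (attr_name, enum_name) parsed from an 'unhandled type' error,
    or None if the message does not match."""
    match = _PATTERN.search(message)
    if match is None:
        return None
    code = int(match.group(2))
    return match.group(1), TF_ATTR_TYPES.get(code, f"unknown ({code})")


msg = "Attr sample_locs has unhandled type 6"
print(decode_unhandled_attr(msg))  # ('sample_locs', 'TF_ATTR_TENSOR')
```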
