TensorFlow为Attr类型“张量”接受哪些Python类型?

时间:2017-05-24 20:14:45

标签: python tensorflow

我在C ++中定义了一个新的Op,它接受tensor类型的单个属性,大致跟随these instructions。操作码的剥离版本如下:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

REGISTER_OP("DoStuff")
    .Attr("attr: tensor = { dtype: DT_FLOAT }")
    .Input("in: float")
    .Output("out: float");

class DoStuffOp : public OpKernel {
public:
    explicit DoStuffOp(OpKernelConstruction *context) : OpKernel(context) {
        OP_REQUIRES_OK(context, context->GetAttr("attr", &attr_));
        // ...
    }

    void Compute(OpKernelContext *context) override {
        // ...
    }

private:
    Tensor attr_;
};

REGISTER_KERNEL_BUILDER(Name("DoStuff").Device(DEVICE_CPU), DoStuffOp);

我可以将Op编译成.so文件。但是,我无法弄清楚如何成功传递attr的值。当我在Python中运行以下内容时:

import tensorflow as tf
dostufflib = tf.load_op_library('build/do_stuff.so')
sess = tf.InteractiveSession()

A = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]
X = tf.Variable(tf.constant(1.0))

Y = dostufflib.do_stuff(X, A)

我得到TypeError: Don't know how to convert [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] to a TensorProto for argument 'attr'。我所做的一切似乎都不满足类型转换:listnumpy数组,tf.Tensortf.Variable等。如何将Python变量作为张量属性传递给Op?

1 个答案:

答案 0 :(得分:2)

经过更多的搜索后,我找到了tf.contrib.util.make_tensor_proto,这是一个将python标量,python列表,numpy ndarray或numpy标量转换为tf.TensorProto对象的函数。以下作品:

A = tf.contrib.util.make_tensor_proto([[1.0, 2.0, 3.0],[4.0, 5.0, 6.0]])
X = tf.Variable(tf.constant(1.0))

Y = dostufflib.do_stuff(X, A)