Tensorflow:如何添加用户自定义操作接受两个1D vec张量并输出标量?

时间:2015-12-11 10:35:56

标签: tensorflow

我正在尝试下面但没有工作。

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
REGISTER_OP("Auc")
.Input("predicts: T1")
.Input("labels: T2")
.Output("z: double")
.Attr("T1: {float, double}")
.Attr("T2: {int32, int64}")
.SetIsCommutative()
.Doc(R"doc(
Given preidicts and labels output it's auc
)doc");

class AucOp : public OpKernel {
public:
explicit AucOp(OpKernelConstruction* context) : OpKernel(context) {}

void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& predicts_tensor = context->input(0);
    const Tensor& labels_tensor = context->input(1);
    auto predicts = predicts_tensor.flat<double>();
    auto labels = labels_tensor.flat<int32>();

    // Create an output tensor
    Tensor* output_tensor = NULL;
    TensorShape output_shape;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor));

    output_tensor->flat<double>().setConstant(predicts(0) * labels(0));
}
};

REGISTER_KERNEL_BUILDER(Name("Auc").Device(DEVICE_CPU), AucOp);


test.py 

predicts = tf.constant([0.8, 0.5, 0.12])
labels = tf.constant([-1, 1, 1])

output = tf.user_ops.auc(predicts, labels)

with tf.Session() as sess:
    init = tf.initialize_all_variables()
    sess.run(init)

    print output.eval()

./ test.py I tensorflow / core / common_runtime / local_device.cc:40]本地设备intra op parallelism threads:8 I tensorflow / core / common_runtime / direct_session.cc:60]直接会话互操作并行线程:8 F ./tensorflow/core/public/tensor.h:453]检查失败:dtype()== DataTypeToEnum :: v()(1 vs. 2) 中止

1 个答案:

答案 0 :(得分:2)

问题是Python程序中的predicts张量类型为float,并且您的操作注册将此作为predicts输入的有效类型接受(从{{1}开始}可以是T1float),但是 double假定AucOp::Compute()输入始终具有类型predicts(在通话中)到double)。当你要求不同类型的值时,predicts_tensor.flat<double>()类不会转换张量中元素的类型,而是会引发致命错误。

有几种可能的解决方案:

  1. 为了让事情快速起作用,您可以将Python程序中tensorflow::Tensor的类型更改为predicts(这是Python前端中tf.float64的同义词):

    double
  2. 您可以从定义一个仅接受单一类型输入的简单操作开始:

    predicts = tf.constant([0.8, 0.5, 0.12], dtype=tf.float64)
    
  3. 您可以在REGISTER_OP("Auc") .Input("predicts: double") .Input("labels: int32") ...; 方法中添加代码来测试输入类型并根据需要访问输入值。 (使用AucOp::Compute()查找第i个输入的类型。

  4. 您可以定义模板化的类this->input_type(i),然后在AucOp<TPredict, TLabel>调用中使用TypeConstraint<>为预测和标签类型的四种有效组合中的每一种定义特化。这看起来像是:

    REGISTER_KERNEL_BUILDER