我正在尝试下面但没有工作。
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
REGISTER_OP("Auc")
.Input("predicts: T1")
.Input("labels: T2")
.Output("z: double")
.Attr("T1: {float, double}")
.Attr("T2: {int32, int64}")
.SetIsCommutative()
.Doc(R"doc(
Given preidicts and labels output it's auc
)doc");
class AucOp : public OpKernel {
public:
explicit AucOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the input tensor
const Tensor& predicts_tensor = context->input(0);
const Tensor& labels_tensor = context->input(1);
auto predicts = predicts_tensor.flat<double>();
auto labels = labels_tensor.flat<int32>();
// Create an output tensor
Tensor* output_tensor = NULL;
TensorShape output_shape;
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor));
output_tensor->flat<double>().setConstant(predicts(0) * labels(0));
}
};
REGISTER_KERNEL_BUILDER(Name("Auc").Device(DEVICE_CPU), AucOp);
test.py
predicts = tf.constant([0.8, 0.5, 0.12])
labels = tf.constant([-1, 1, 1])
output = tf.user_ops.auc(predicts, labels)
with tf.Session() as sess:
init = tf.initialize_all_variables()
sess.run(init)
print output.eval()
./ test.py I tensorflow / core / common_runtime / local_device.cc:40]本地设备intra op parallelism threads:8 I tensorflow / core / common_runtime / direct_session.cc:60]直接会话互操作并行线程:8 F ./tensorflow/core/public/tensor.h:453]检查失败:dtype()== DataTypeToEnum :: v()(1 vs. 2) 中止
答案 0 :(得分:2)
问题是Python程序中的predicts
张量类型为float
,并且您的操作注册将此作为predicts
输入的有效类型接受(从{{1}开始}可以是T1
或float
),但是 double
假定AucOp::Compute()
输入始终具有类型predicts
(在通话中)到double
)。当你要求不同类型的值时,predicts_tensor.flat<double>()
类不会转换张量中元素的类型,而是会引发致命错误。
有几种可能的解决方案:
为了让事情快速起作用,您可以将Python程序中tensorflow::Tensor
的类型更改为predicts
(这是Python前端中tf.float64
的同义词):
double
您可以从定义一个仅接受单一类型输入的简单操作开始:
predicts = tf.constant([0.8, 0.5, 0.12], dtype=tf.float64)
您可以在REGISTER_OP("Auc")
.Input("predicts: double")
.Input("labels: int32")
...;
方法中添加代码来测试输入类型并根据需要访问输入值。 (使用AucOp::Compute()
查找第i个输入的类型。
您可以定义模板化的类this->input_type(i)
,然后在AucOp<TPredict, TLabel>
调用中使用TypeConstraint<>
为预测和标签类型的四种有效组合中的每一种定义特化。这看起来像是:
REGISTER_KERNEL_BUILDER