我在the official instructions之后在Tensorflow中实现了一个自定义运算符。我使用了T模板,以便我的操作符可以使用float类型和双倍类型的输入。这是我的.cc文件的负责人:
#include <stdio.h>
#include <math.h>
#include <cfloat>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "work_sharder.h"
using namespace tensorflow;
typedef Eigen::ThreadPoolDevice CPUDevice;
REGISTER_OP("NewOp")
.Attr("T: {float, double}")
.Attr("attr1: int")
.Attr("attr2: float")
.Input("input1: T")
.Input("input2: T")
.Output("output: T");
template <typename Device, typename T>
class NewOpOp : public OpKernel {
public:
explicit NewOpOp(OpKernelConstruction* context) : OpKernel(context) {
...
};
它正确编译,直到达到:
REGISTER_KERNEL_BUILDER(Name("NewOp").Device(DEVICE_CPU).TypeConstraint<T>("T"), NewOpOp<CPUDevice, T>);
错误消息显示括号内的Ts未在此范围内声明,而T模板在第一个块的末尾明确定义! 如果我为此行更改此行:
REGISTER_KERNEL_BUILDER(Name("NewOp").Device(DEVICE_CPU).TypeConstraint<float>("T"), NewOpOp<CPUDevice, float>);
编译错误消失了,但当然它强制输入浮动。
答案 0 :(得分:0)
我可以通过为每个op(float和double)
注册两个不同的内核来解决这个问题REGISTER_KERNEL_BUILDER(Name("NewOp").Device(DEVICE_CPU).TypeConstraint<float>("T"), NewOpOp<CPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("NewOp").Device(DEVICE_CPU).TypeConstraint<double>("T"), NewOpOp<CPUDevice, double>);
它有点难看,但它确实有效。