在Tensorflow中为float32和float64实现新运算符时出错

时间:2018-06-13 12:35:01

标签: python c++ tensorflow

我在the official instructions之后在Tensorflow中实现了一个自定义运算符。我使用了T模板,以便我的操作符可以使用float类型和双倍类型的输入。这是我的.cc文件的负责人:

#include <stdio.h>
#include <math.h>
#include <cfloat>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "work_sharder.h"

using namespace tensorflow;
typedef Eigen::ThreadPoolDevice CPUDevice;

REGISTER_OP("NewOp")
    .Attr("T: {float, double}")
    .Attr("attr1: int")
    .Attr("attr2: float")
    .Input("input1: T")
    .Input("input2: T")
    .Output("output: T");



template <typename Device, typename T>
class NewOpOp : public OpKernel {
  public:
    explicit NewOpOp(OpKernelConstruction* context) : OpKernel(context) {
...
};

它正确编译,直到达到:

REGISTER_KERNEL_BUILDER(Name("NewOp").Device(DEVICE_CPU).TypeConstraint<T>("T"), NewOpOp<CPUDevice, T>);

错误消息显示括号内的Ts未在此范围内声明,而T模板在第一个块的末尾明确定义! 如果我为此行更改此行:

REGISTER_KERNEL_BUILDER(Name("NewOp").Device(DEVICE_CPU).TypeConstraint<float>("T"), NewOpOp<CPUDevice, float>);

编译错误消失了,但当然它强制输入浮动。

1 个答案:

答案 0 :(得分:0)

我可以通过为每个op(float和double)

注册两个不同的内核来解决这个问题
REGISTER_KERNEL_BUILDER(Name("NewOp").Device(DEVICE_CPU).TypeConstraint<float>("T"), NewOpOp<CPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("NewOp").Device(DEVICE_CPU).TypeConstraint<double>("T"), NewOpOp<CPUDevice, double>);

它有点难看,但它确实有效。