为什么Tensorflow找不到我自定义Op的GPU内核?

时间:2018-01-31 21:59:30

标签: c++ tensorflow cuda

我已经按照Tensorflow网站上的Adding a New Op示例添加GPU内核到我的自定义操作系统。它编译得很好,但是当我尝试使用GPU内核时,Tensorflow给了我这个错误:

InvalidArgumentError (see above for traceback): Cannot assign a device to node 'b': Could not satisfy explicit device specification '/device:GPU:0' because no supported kernel for GPU devices is available.

有人能指出我出错的地方吗?我正在使用上面链接的页面上给出的确切示例代码,以及以下修复/更改(可在各种其他论坛和StackOverflow页面上找到) ):

  • #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"添加到kernel_example.cu.cc
  • "example.h""kernel_example.h"
  • 中将kernel_example.cu.cc更改为kernel_example.cc
  • //添加到#endif // KERNEL_EXAMPLE_H_
  • 中的第kernel_example.h
  • kernel_example.h中的部分特化修正为:

    template <typename T> struct ExampleFunctor<Eigen::GpuDevice, T> { ... };

  • kernel_example.cc注册了操作:

    REGISTER_OP("Example").Attr("T: {float, int32} = DT_FLOAT").Input("input: T").Output("output: T"));

  • .cu.cc文件的最后几行更正为template struct ExampleFunctor<GPUDevice, float>;struct缺失)

生成文件:

TF_LIB := $(shell python -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())' 2>/dev/null)
TF_INC := $(shell python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())' 2>/dev/null)

CUDA_LIB=/z/sw/packages/cuda/8.0/lib64

all: kernel_example.cu.cc kernel_example.cc
    nvcc -std=c++11 -c -o kernel_example.cu.o kernel_example.cu.cc -I $(TF_INC) -I$(TF_INC)/external/nsync/public -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -D_MWAITXINTRIN_H_INCLUDED --expt-relaxed-constexpr

    g++ -std=c++11 -shared -o kernel_example.so kernel_example.cc kernel_example.cu.o -I $(TF_INC) -I$(TF_INC)/external/nsync/public -fPIC -L$(CUDA_LIB) -lcudart -L$(TF_LIB) -D_GLIBCXX_USE_CXX11_ABI=0 -D GOOGLE_CUDA=1

编辑:如我在下面的回答中所述,问题产生于g++来电缺少-D GOOGLE_CUDA=1

测试代码:

import tensorflow as tf
example_lib = tf.load_op_library('kernel_example.so')

with tf.device('/gpu:0'):
    a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name='a')
    b = example_lib.example(a, name='b')
sess = tf.Session()
print(sess.run(b))

1 个答案:

答案 0 :(得分:2)

Tensorflow文档与正常情况一样糟糕且不完整。 g++命令缺失-D GOOGLE_CUDA=1。我会编辑问题以反映这一点。