如何从C ++代码运行自定义GPU tensorflow :: op?

时间:2018-06-11 13:59:34

标签: c++ tensorflow

我按照这些例子在TensorFlow中编写自定义操作:
Adding a New Op
cuda_op_kernel
将功能更改为我需要执行的操作 但所有示例都是Python代码中的测试 我需要从c ++代码运行我的操作,我该怎么做?

1 个答案:

答案 0 :(得分:0)

这个简单的例子展示了使用C++ API构建和执行图表:

// tensorflow/cc/example/example.cc

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor.h"

int main() {
  using namespace tensorflow;
  using namespace tensorflow::ops;
  Scope root = Scope::NewRootScope();
  // Matrix A = [3 2; -1 0]
  auto A = Const(root, { {3.f, 2.f}, {-1.f, 0.f} });
  // Vector b = [3 5]
  auto b = Const(root, { {3.f, 5.f} });
  // v = Ab^T
  auto v = MatMul(root.WithOpName("v"), A, b, MatMul::TransposeB(true)); // <- in your case you should put here your custom Op
  std::vector<Tensor> outputs;
  ClientSession session(root);
  // Run and fetch v
  TF_CHECK_OK(session.Run({v}, &outputs));
  // Expect outputs[0] == [19; -3]
  LOG(INFO) << outputs[0].matrix<float>();
  return 0;
}

与Python相同,首先需要在范围内构建计算图,在这种情况下,只有一个矩阵乘法,其终点在v。然后,您需要为范围打开一个新会话(session),然后在图表上运行它。在这种情况下,没有提要字典,但在页面的末尾有一个关于如何提供值的示例:

Scope root = Scope::NewRootScope();
auto a = Placeholder(root, DT_INT32);
// [3 3; 3 3]
auto b = Const(root, 3, {2, 2});
auto c = Add(root, a, b);
ClientSession session(root);
std::vector<Tensor> outputs;

// Feed a <- [1 2; 3 4]
session.Run({ {a, { {1, 2}, {3, 4} } } }, {c}, &outputs);
// outputs[0] == [4 5; 6 7]

此处报告的所有代码段均来自TensorFlow的C ++ API指南

如果要调用自定义OP,则必须使用几乎相同的代码。我在this repository中有一个自定义操作,我将用作示例代码。 OP已经注册:

REGISTER_OP("ZeroOut")
  .Input("to_zero: int32")
  .Output("zeroed: int32")
  .SetShapeFn([](::tensorflow::shape_inference::InferenceContext *c) {
    c->set_output(0, c->input(0));
    return Status::OK();
  });

并且Op被定义为cuda file中的Cuda内核。要启动Op,我必须(再次)创建一个新的计算图,注册我的操作,打开一个会话并使其从我的代码运行:

Scope root = Scope::NewRootScope();
// Matrix A = [3 2; -1 0]
auto A = Const(root, { {3.f, 2.f}, {-1.f, 0.f} });
auto v = ZeroOut(root.WithOpName("v"), A); 
std::vector<Tensor> outputs;
ClientSession session(root);
// Run and fetch v
TF_CHECK_OK(session.Run({v}, &outputs));
LOG(INFO) << outputs[0].matrix<float>();