我在tensorflow中实现了一个新的自定义c ++ op。在相应的操作内核的Compute函数中,调用了一些标准的ops(例如MatMul)。 主要的源代码是:
REGISTER_OP("NewOp")
.Input("input: int32")
.Output("output: int32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
});
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor.h"
using namespace tensorflow;
using namespace tensorflow::ops;
class MyNewOp : public OpKernel {
public:
explicit MyNewOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the input tensor
……
// Create an output tensor
……
Scope root = Scope::NewRootScope();
auto A = Const(root, { {35.f, 22.f}, {-10.f, 0.f} });
auto b = Const(root, { {30.f, 55.f} });
auto v = MatMul(root.WithOpName("v"), A, b, MatMul::TransposeB(true));
std::vector<Tensor> results;
ClientSession session(root);
TF_CHECK_OK(session.Run({v}, &results));
// Set the output tensor according to the results of MatMul
……
}
};
REGISTER_KERNEL_BUILDER(Name("NewOp").Device(DEVICE_CPU), MyNewOp);
相应的Bazel BUILD文件是:
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
tf_custom_op_library(
name = "MyNewOp.so",
srcs = ["mynewop.cc"],
deps = [
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:client_session",
"//tensorflow/core:tensorflow",
],
)
当我构建上述目标时,Bazel会返回错误:
tensorflow/cc:cc_ops cannot depend on tensorflow/core:framework
我该如何解决这个问题?我想知道我是否可以在新的自定义c ++ op中调用ternsorflow预定义操作?非常感谢你!
答案 0 :(得分:0)
您遇到的问题是因为您的自定义操作取决于此rule明确禁止的tensorflow/core:framework
:
disallowed_deps=[
clean_dep("//tensorflow/core:framework"),
clean_dep("//tensorflow/core:lib")
]
最好的方法是找到另一种解决方案。
如果您确实希望拥有此依赖关系,那么 hacky方式会重新实现tf_custom_op_library
规则而不会出现禁用依赖关系。
这可以通过以下方式完成:
load("//tensorflow:tensorflow.bzl", "tf_copts")
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
load("//tensorflow:tensorflow.bzl", "clean_dep")
tf_cc_shared_object(
name = "MyNewOp.so",
srcs = ["mynewop.cc"],
copts = tf_copts(is_external=True),
linkstatic = 1,
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:client_session",
"//tensorflow/core:tensorflow",],
linkopts= select({
"//conditions:default": [
"-lm",
],
clean_dep("//tensorflow:windows"): [],
clean_dep("//tensorflow:windows_msvc"): [],
clean_dep("//tensorflow:darwin"): [],
}),
)
工作正常:
Target //tensorflow/user_ops:MyNewOp.so up-to-date:
bazel-bin/tensorflow/user_ops/MyNewOp.so
INFO: Elapsed time: 46.399s, Critical Path: 19.71s
INFO: 397 processes, local.
INFO: Build completed successfully, 400 total actions