我目前的代码如下：
// For Eigen::ThreadPoolDevice.
#define EIGEN_USE_THREADS 1
#include <new>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
using namespace tensorflow;
// Op producing a scalar resource handle to an ArrayContainer.
//   T           - element dtype of the container (checked by the kernel's
//                 VerifyResource against an existing container's dtype).
//   container   - resource container name (empty = default).
//   shared_name - resource name used for sharing (empty = default).
REGISTER_OP("ArrayContainerCreate")
    .Attr("T: type")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Output("resource: resource")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape)
    .Doc(R"doc(Array container, random index access)doc");

// Op that takes an ArrayContainer resource handle and outputs the container's
// size as a scalar int32.
REGISTER_OP("ArrayContainerGetSize")
    .Input("handle: resource")
    .Output("out: int32")
    .SetShapeFn(shape_inference::ScalarShape);
// Resource object managed through TensorFlow's ResourceMgr.
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/resource_mgr.h
struct ArrayContainer : public ResourceBase {
  // `dtype` is the element type requested through the op's `T` attr.
  // ArrayContainerCreateOp::VerifyResource compares it against `T`.
  // Declared explicit so a DataType never silently converts to a container.
  explicit ArrayContainer(const DataType& dtype) : dtype_(dtype) {}

  string DebugString() override { return "ArrayContainer"; }

  // No element storage is held yet, so zero bytes are reported.
  int64 MemoryUsed() const override { return 0; }

  // Guards the container's mutable state.
  mutex mu_;
  // Element dtype, fixed at construction.
  const DataType dtype_;

  // Returns the number of elements in the container.
  // NOTE(review): currently a stub that always returns 42.
  int32 get_size() {
    mutex_lock l(mu_);
    return static_cast<int32>(42);
  }
};
// Kernel for ArrayContainerCreate. ResourceOpKernel calls CreateResource /
// VerifyResource below and emits the resulting resource handle.
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/resource_op_kernel.h
//
// NOTE(review): if a handle built for this resource reports device, container
// and name all as "14ArrayContainer" (the Itanium-mangled name of
// ArrayContainer), while the real device is
// "/job:localhost/replica:0/task:0/cpu:0", the cause is presumably a binary
// mismatch between this library and the TensorFlow build rather than this
// class. Candidates include a different _GLIBCXX_USE_CXX11_ABI setting or a
// different protobuf library. TODO confirm against the flags TensorFlow was
// compiled with.
class ArrayContainerCreateOp : public ResourceOpKernel<ArrayContainer> {
 public:
  explicit ArrayContainerCreateOp(OpKernelConstruction* context)
      : ResourceOpKernel(context) {
    // Element dtype requested for the container.
    OP_REQUIRES_OK(context, context->GetAttr("T", &dtype_));
  }

 private:
  // NOTE(review): not marked `override`, and nothing visible here calls them.
  // Confirm whether the base class still declares these hooks.
  virtual bool IsCancellable() const { return false; }
  virtual void Cancel() {}

  // Allocates a new container of dtype_. Annotated as requiring mu_ to be held.
  Status CreateResource(ArrayContainer** ret) override
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    // Plain `new` never yields nullptr (it throws or aborts). The nothrow form
    // makes the ResourceExhausted path actually reachable.
    *ret = new (std::nothrow) ArrayContainer(dtype_);
    if (*ret == nullptr) {
      return errors::ResourceExhausted("Failed to allocate");
    }
    return Status::OK();
  }

  // Rejects an existing container whose dtype differs from this op's `T`.
  Status VerifyResource(ArrayContainer* ar) override {
    if (ar->dtype_ != dtype_) {
      return errors::InvalidArgument("Data type mismatch: expected ",
                                     DataTypeString(dtype_), " but got ",
                                     DataTypeString(ar->dtype_), ".");
    }
    return Status::OK();
  }

  // Defaulted so the member is never read uninitialized if GetAttr fails.
  DataType dtype_ = DT_INVALID;
};
// CPU kernel registration for ArrayContainerCreate.
REGISTER_KERNEL_BUILDER(Name("ArrayContainerCreate").Device(DEVICE_CPU),
                        ArrayContainerCreateOp);
class ArrayContainerGetSizeOp : public OpKernel {
public:
using OpKernel::OpKernel;
void Compute(OpKernelContext* context) override {
ArrayContainer* ar;
OP_REQUIRES_OK(context, GetResourceFromContext(context, "handle", &ar));
core::ScopedUnref unref(ar);
int32 size = ar->get_size();
Tensor* tensor_size = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &tensor_size));
tensor_size->flat<int32>().setConstant(size);
}
};
// CPU kernel registration for ArrayContainerGetSize.
REGISTER_KERNEL_BUILDER(Name("ArrayContainerGetSize").Device(DEVICE_CPU),
                        ArrayContainerGetSizeOp);
然后我编译了它。需要说明的是，我最初遇到了 `undefined symbol: _ZN6google8protobuf8internal26fixed_address_empty_stringE` 错误，后来通过添加以下额外的编译器标志解决了这个问题：
# Link the op library against the protobuf C++ extension module that ships with
# the Python protobuf package. Its directory is added both to the link search
# path (-L) and to the runtime rpath. The file is linked by exact name (-l :name).
# NOTE(review): this only works around the undefined-symbol error. If this
# protobuf differs from the one TensorFlow was built with, protobuf-backed
# objects may still be corrupted. Confirm.
import os

from google.protobuf.pyext import _message as msg

lib = msg.__file__
lib_dir = os.path.dirname(lib)
extra_compiler_flags = [
    "-Xlinker", "-rpath", "-Xlinker", lib_dir,
    "-L", lib_dir, "-l", ":" + os.path.basename(lib)]
这个方法是我在 here（原帖中的链接）读到的。
然后我通过 `tf.load_op_library` 将其作为模块（下面代码中的 `mod`）加载，并运行以下 Python 代码：
# `mod` is presumably the module returned by tf.load_op_library(...).
# Create the ArrayContainer resource with int32 elements and get its handle.
handle = mod.array_container_create(T=tf.int32)
# Query the container's size through the resource handle.
size = mod.array_container_get_size(handle=handle)
当我尝试对 `size` 求值时，收到如下错误：
InvalidArgumentError (see above for traceback): Trying to access resource located in device 14ArrayContainer from device /job:localhost/replica:0/task:0/cpu:0
[[Node: ArrayContainerGetSize = ArrayContainerGetSize[_device="/job:localhost/replica:0/task:0/cpu:0"](array_container)]]
其中的设备名称（`14ArrayContainer`）看起来是错乱的。这是为什么？代码有什么问题？
为了进一步排查，我在 `ArrayContainerCreateOp` 中添加了以下代码：
// Debug fragment placed inside ArrayContainerCreateOp (enclosing method not
// shown). It builds a resource handle from cinfo_ and prints its fields, then
// prints the kernel's own device name and cinfo_ name for comparison.
ResourceHandle rhandle = MakeResourceHandle<ArrayContainer>(context, cinfo_.container(), cinfo_.name());
printf("created. device: %s\n", rhandle.device().c_str());
printf("container: %s\n", rhandle.container().c_str());
printf("name: %s\n", rhandle.name().c_str());
// Reference values read directly from the device and from cinfo_.
printf("actual device: %s\n", context->device()->attributes().name().c_str());
printf("actual name: %s\n", cinfo_.name().c_str());
得到的输出如下：
created. device: 14ArrayContainer
container: 14ArrayContainer
name: 14ArrayContainer
actual device: /job:localhost/replica:0/task:0/cpu:0
actual name: _2_array_container
显然哪里出了问题。
这看起来像是 protobuf 出了问题？也许我链接了错误的库？但我还没有找到应该链接哪个库。
（关于这个问题，我也在 here（原帖中的链接）发布过。）