使用C ++时如何加载自定义操作库?

时间:2019-09-09 15:50:00

标签: c++ tensorflow

我已经用bazel构建了一个非常简单的自定义操作zero_out.dll,在使用python时可以使用。

import tensorflow as tf
zero_out_module = tf.load_op_library('./zero_out.dll')
with tf.Session(''):
  zero_out_module.zero_out([[1, 2], [3, 4]]).eval()

但是我必须使用C ++运行推论,是否有任何c ++ api具有与tf.load_op_library类似的功能,因为似乎在tf.load_op_library中做了很多注册工作,TF还没有对应的c ++ API?

1 个答案:

答案 0 :(得分:2)

尽管在C ++中似乎没有针对该API的公共API,但库加载功能在TensorFlow API for Ctf.load_library使用的API)中公开。没有“不错”的文档,但是您可以在c/c_api.h中找到它们:

// --------------------------------------------------------------------------
// Load plugins containing custom ops and kernels

// TF_Library holds information about dynamically loaded TensorFlow plugins.
typedef struct TF_Library TF_Library;

// Load the library specified by library_filename and register the ops and
// kernels present in that library.
//
// Pass "library_filename" to a platform-specific mechanism for dynamically
// loading a library. The rules for determining the exact location of the
// library are platform-specific and are not documented here.
//
// On success, place OK in status and return the newly created library handle.
// The caller owns the library handle.
//
// On failure, place an error status in status and return NULL.
TF_CAPI_EXPORT extern TF_Library* TF_LoadLibrary(const char* library_filename,
                                                 TF_Status* status);

// Get the OpList of OpDefs defined in the library pointed by lib_handle.
//
// Returns a TF_Buffer. The memory pointed to by the result is owned by
// lib_handle. The data in the buffer will be the serialized OpList proto for
// ops defined in the library.
TF_CAPI_EXPORT extern TF_Buffer TF_GetOpList(TF_Library* lib_handle);

// Frees the memory associated with the library handle.
// Does NOT unload the library.
TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle);

这些函数实际上确实调用C ++代码(请参见c/c_api.cc中的源代码)。但是,在core/framework/load_library.cc中定义的被调用函数没有要包含的标头。他们在c/c_api.cc中使用的C ++代码中使用它的解决方法是自己声明函数,并链接TensorFlow库。

namespace tensorflow {
// Helpers for loading a TensorFlow plugin (a .so file).
Status LoadLibrary(const char* library_filename, void** result,
                   const void** buf, size_t* len);
}

据我所知,没有API可以卸载该库。 C API仅允许您删除库句柄对象。这只是通过释放指针来完成的,但是如果要避免麻烦,您可能应该使用core/platform/mem.h中声明的TensorFlow tensorflow::port:free给定的释放函数。同样,如果您不能或不希望包含该函数,则可以自己声明该函数,并且该函数也应正常工作。

namespace tensorflow {
namespace port {
void Free(void* ptr);
}
}