pybind11如何将自定义类型转换程序用于简单的示例类

时间:2019-08-22 21:40:04

标签: python c++ pybind11

动机

我目前正在尝试在pybind11和python中使用自定义类。其背后的动机是在c ++中使用经过python训练的分类器。

有一些可行的示例,例如在官方文档https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html中,或在tdegeus https://github.com/tdegeus/pybind11_examples/blob/master/09_numpy_cpp-custom-matrix/pybind_matrix.h中的漂亮示例中 但是,我仍然很难将其转换为带有自定义类的简单示例。

下面是一个最小的工作示例,它在C ++中具有pybind11的包装函数,并在python中使用了它。

#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;

//A custom data class
class class_DATA{
    public: 
        std::vector<int> a;
};

//the actual function
class_DATA func_test(class_DATA data){
    std::vector<int> a = {1,2};
    data.a = a;
    return data; 
}

//the wrapper function that communicates with python
py::object wrapper(py::object data_py){
    class_DATA data_cpp = class_DATA();

    py::object a    = data_py.attr("a");
    data_cpp.a      = a.cast<std::vector<int>>();

    class_DATA data_result_cpp = func_test(data_cpp);

    py::object data_result_py   = data_py;
    data_result_py.attr("a")    = py::cast(data_result_cpp.a);

    return data_result_py;
}

//defining the modules
PYBIND11_MODULE(TEST,m){
  m.doc() = "pybind11 example plugin";
  //the test function
  m.def("func_test",  &wrapper);
 //the custom class
 py::class_<class_DATA>(m, "class_DATA", py::dynamic_attr())
  .def(py::init<>())    
  .def_readwrite("a", &class_DATA::a);
}

from build.Debug.TEST import func_test 

//the custom data class in python
class class_DATA_py:
    def __init__(self):
        self.a   = [0] 
test_py     = class_DATA_py()
data_return = func_test(test_py)

输出输入为data.a = 0,输出data_return.a = [1,2]。

问题

如何用自定义脚轮替换包装函数调用?最可能具有以下概述的一般形状(?)

namespace pybind11 { namespace detail {
    template <> struct type_caster<class_DATA> : public type_caster_base<class_DATA> {
        using base = type_caster_base<class_DATA>;
        public:
            PYBIND11_TYPE_CASTER(class_DATA, _("class_DATA"));
            // Conversion part 1 (Python->C++): 
            bool load(py::handle src, bool convert){
                PyObject *source = src.ptr();
                //what to do here?
                return true;
            }
            // Conversion part 2 (C++ -> Python): 
            static py::handle cast(class_DATA src, py::return_value_policy policy, py::handle parent){
                //what to do here?
                return base::cast(src, policy, parent);
            }
    };
}}

c ++不是我的强项,所以当我插入此部分时,python反复崩溃而没有任何错误消息,因此我们将不胜感激。

0 个答案:

没有答案