用于类STL矢量类的鲁棒型脚轮

时间:2018-07-16 08:22:47

标签: pybind11

我有一个非常类似于STL-vector的类(区别对于pybind11类型的caster并不重要,因此在这里我将忽略它们)。我已经为此类编写了类型转换程序。下面是我的代码的最小工作示例。代码下方包含一个显示问题的示例。

问题是我的脚轮非常有限(因为我使用过py::array_t)。原则上,该接口确实接受元组,列表和numpy-arrays。但是,当我基于类型名重载时,输入的元组和列表的接口将失败(即使选择了错误的类型,也仅选择了第一个重载)。

我的问题是:如何使类型转换程序更可靠?有没有一种有效的方法可以为STL-vector类重新使用尽可能多的现有类型转换程序?

C ++代码(包括pybind11接口)

#include <iostream>
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>

namespace py = pybind11;

// class definition
// ----------------

template<typename T>
class Vector
{
private:

  std::vector<T> mData;

public:

  Vector(){};
  Vector(size_t N) { mData.resize(N); };

  auto   data ()       { return mData.data (); };
  auto   data () const { return mData.data (); };
  auto   begin()       { return mData.begin(); };
  auto   begin() const { return mData.begin(); };
  auto   end  ()       { return mData.end  (); };
  auto   end  () const { return mData.end  (); };
  size_t size () const { return mData.size (); };

  std::vector<size_t> shape()   const { return std::vector<size_t>(1, mData.size()); }
  std::vector<size_t> strides() const { return std::vector<size_t>(1, sizeof(T)   ); }

  template<typename It> static Vector<T> Copy(It first, It last) {
    Vector out(last-first);
    std::copy(first, last, out.begin());
    return out;
  }
};

// C++ functions: overload based on type
// -------------------------------------

Vector<int>    foo(const Vector<int>    &A){ std::cout << "int"    << std::endl; return A; }
Vector<double> foo(const Vector<double> &A){ std::cout << "double" << std::endl; return A; }

// pybind11 type caster
// --------------------

namespace pybind11 {
namespace detail {

template<typename T> struct type_caster<Vector<T>>
{
public:

  PYBIND11_TYPE_CASTER(Vector<T>, _("Vector<T>"));

  bool load(py::handle src, bool convert)
  {
    if ( !convert && !py::array_t<T>::check_(src) ) return false;

    auto buf = py::array_t<T, py::array::c_style | py::array::forcecast>::ensure(src);
    if ( !buf ) return false;

    auto rank = buf.ndim();
    if ( rank != 1 ) return false;

    value = Vector<T>::Copy(buf.data(), buf.data()+buf.size());

    return true;
  }

  static py::handle cast(const Vector<T>& src, py::return_value_policy policy, py::handle parent)
  {
    py::array a(std::move(src.shape()), std::move(src.strides()), src.data());

    return a.release();
  }
};

}} // namespace pybind11::detail

// Python interface
// ----------------

PYBIND11_MODULE(example,m)
{
  m.doc() = "pybind11 example plugin";

  m.def("foo", py::overload_cast<const Vector<int   > &>(&foo));
  m.def("foo", py::overload_cast<const Vector<double> &>(&foo));
}

示例

import numpy as np
import example

print(example.foo((1,2,3)))
print(example.foo((1.5,2.5,3.5)))

print(example.foo(np.array([1,2,3])))
print(example.foo(np.array([1.5,2.5,3.5])))

输出:

int
[1 2 3]
int
[1 2 3]
int
[1 2 3]
double
[1.5 2.5 3.5]

1 个答案:

答案 0 :(得分:0)

一个非常简单的解决方案是专门研究pybind11::detail::list_caster。现在,类型转换程序变得像

一样容易
namespace pybind11 {
namespace detail {

template <typename Type> struct type_caster<Vector<Type>> : list_caster<Vector<Type>, Type> { };

}} // namespace pybind11::detail

请注意,这确实需要Vector拥有以下方法:

  • clear()
  • push_back(const Type &value)
  • reserve(size_t n)(在测试中似乎是可选的)

完整示例

#include <iostream>
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>

namespace py = pybind11;

// class definition
// ----------------

template<typename T>
class Vector
{
private:

  std::vector<T> mData;

public:

  Vector(){};
  Vector(size_t N) { mData.resize(N); };

  auto   data ()       { return mData.data (); };
  auto   data () const { return mData.data (); };
  auto   begin()       { return mData.begin(); };
  auto   begin() const { return mData.begin(); };
  auto   end  ()       { return mData.end  (); };
  auto   end  () const { return mData.end  (); };
  size_t size () const { return mData.size (); };

  void push_back(const T &value) { mData.push_back(value); }
  void clear() { mData.clear(); }
  void reserve(size_t n) { mData.reserve(n); }

  std::vector<size_t> shape()   const { return std::vector<size_t>(1, mData.size()); }
  std::vector<size_t> strides() const { return std::vector<size_t>(1, sizeof(T)   ); }

  template<typename It> static Vector<T> Copy(It first, It last) {
    printf("Vector<T>::Copy %s\n", __PRETTY_FUNCTION__);
    Vector out(last-first);
    std::copy(first, last, out.begin());
    return out;
  }
};

// C++ functions: overload based on type
// -------------------------------------

Vector<int>    foo(const Vector<int>    &A){ std::cout << "int"    << std::endl; return A; }
Vector<double> foo(const Vector<double> &A){ std::cout << "double" << std::endl; return A; }

// pybind11 type caster
// --------------------

    namespace pybind11 {
    namespace detail {

    template <typename Type> struct type_caster<Vector<Type>> : list_caster<Vector<Type>, Type> { };

    }} // namespace pybind11::detail

// Python interface
// ----------------

PYBIND11_MODULE(example,m)
{
  m.doc() = "pybind11 example plugin";

  m.def("foo", py::overload_cast<const Vector<double> &>(&foo));
  m.def("foo", py::overload_cast<const Vector<int   > &>(&foo));

}