pybind11从C ++修改numpy数组

时间:2019-02-20 19:03:10

标签: python pybind11

编辑:它现在可以工作,我不知道为什么。不要以为我改变了什么

我想传入并用pybind11修改一个大的numpy数组。因为它很大,所以我想避免复制它并返回一个新的。

代码如下:

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <vector>

// C++ code
void calc_sum_cost(float* ptr, int N, int M, float* ptr_cost) {
  for(int32_t i = 1; i < N; i++) {
    for(int32_t j = 1; j < M; j++) {
      float upc = ptr[(i-1) * M + j];
      float leftc = ptr[i * M + j - 1];
      float diagc = ptr[(i-1) * M + j - 1];
      float transition_cost = std::min(upc, std::min(leftc, diagc));
      if (transition_cost == diagc) {
        transition_cost += 2 * ptr_cost[i*M + j];
      } else {
        transition_cost += ptr_cost[i*M + j];
      }
      std::cout << transition_cost << std::endl;
      ptr[i * M + j] = transition_cost;
    }
  }
}

// Interface

namespace py = pybind11;

// wrap C++ function with NumPy array IO
py::object wrapper(py::array_t<float> array,
                  py::array_t<float> arrayb) {
  // check input dimensions
  if ( array.ndim()     != 2 )
    throw std::runtime_error("Input should be 2-D NumPy array");

  auto buf = array.request();
  auto buf2 = arrayb.request();
  if (buf.size != buf2.size) throw std::runtime_error("sizes do not match!");

  int N = array.shape()[0], M = array.shape()[1];

  float* ptr = (float*) buf.ptr;
  float* ptr_cost = (float*) buf2.ptr;
  // call pure C++ function
  calc_sum_cost(ptr, N, M, ptr_cost);
  return py::cast<py::none>(Py_None);
}

PYBIND11_MODULE(fast,m) {
  m.doc() = "pybind11 plugin";
  m.def("calc_sum_cost", &wrapper, "Calculate the length of an array of vectors");
}

我认为py::array::forcecast导致了转换,因此输入矩阵保持不变(在python中)。尽管删除了运行时错误,但是删除了::c_style后,它却运行了,但是再次在python中,numpy数组还是一样。

基本上我的问题是如何使用pybind11传递和修改numpy数组?

1 个答案:

答案 0 :(得分:1)

我只是遇到了同样的问题。如果从Python传递了一个与C ++参数匹配的类型的numpy数组,则没有转换发生,并且可以就地修改数据,即为numpy py::array_t<float>数组中的np.float32参数传递。如果您碰巧传入了np.float64数组(默认类型),则pybind11由于py::array::forcecast模板参数(py::array_t<T>的默认值)而进行了转换,因此您的C ++函数仅获得一个转换为numpy数组的副本,返回后所有更改都将丢失。