我已经编写了一个Python C ++扩展,但是我的一个函数存在问题。 此扩展提供的功能将2个数组作为输入,并产生一个作为输出。
我只离开了功能代码
的相关部分float* forward(float* input, float* kernels, npy_intp* input_dims, npy_intp* kernels_dims){
float* output = new float[output_size];
//some irrelevant matrix operation code
return output;
}
包装器:
static PyObject *module_forward(PyObject *self, PyObject *args)
{
PyObject *input_obj, *kernels_obj;
if (!PyArg_ParseTuple(args, "OO", &input_obj, &kernels_obj))
return NULL;
PyObject *input_array = PyArray_FROM_OTF(input_obj, NPY_FLOAT, NPY_IN_ARRAY);
PyObject *kernels_array = PyArray_FROM_OTF(kernels_obj, NPY_FLOAT, NPY_IN_ARRAY);
if (input_array == NULL || kernels_array == NULL) {
Py_XDECREF(input_array);
Py_XDECREF(kernels_array);
return NULL;
}
float *input = (float*)PyArray_DATA(input_array);
float *kernels = (float*)PyArray_DATA(kernels_array);
npy_intp *input_dims = PyArray_DIMS(input_array);
npy_intp *kernels_dims = PyArray_DIMS(kernels_array);
/////////THE ACTUAL FUNCTION
float* output = forward(input, kernels, input_dims, kernels_dims);
Py_DECREF(input_array);
Py_DECREF(kernels_array);
npy_intp output_dims[4] = {input_dims[0], input_dims[1]-kernels_dims[0]+1, input_dims[2]-kernels_dims[1]+1, kernels_dims[3]};
PyObject* ret_output = PyArray_SimpleNewFromData(4, output_dims, NPY_FLOAT, output);
delete output;//<-----THE PROBLEMATIC LINE////////////////////////////
PyObject *ret = Py_BuildValue("O", ret_output);
Py_DECREF(ret_output);
return ret;
}
我突出显示的删除操作符是魔术发生的地方:没有它,这个函数会泄漏内存,因为内存访问冲突而导致内存崩溃。
有趣的是我写了另一个方法,返回两个数组。因此该函数返回指向两个float *元素的float **:
float** gradients = backward(input, kernels, grads, input_dims, kernel_dims, PyArray_DIMS(grads_array));
Py_DECREF(input_array);
Py_DECREF(kernels_array);
Py_DECREF(grads_array);
PyObject* ret_g_input = PyArray_SimpleNewFromData(4, input_dims, NPY_FLOAT, gradients[0]);
PyObject* ret_g_kernels = PyArray_SimpleNewFromData(4, kernel_dims, NPY_FLOAT, gradients[1]);
delete gradients[0];
delete gradients[1];
delete gradients;
PyObject* ret_list = PyList_New(0);
PyList_Append(ret_list, ret_g_input);
PyList_Append(ret_list, ret_g_kernels);
PyObject *ret = Py_BuildValue("O", ret_list);
Py_DECREF(ret_g_input);
Py_DECREF(ret_g_kernels);
return ret;
请注意,第二个示例完美无缺,没有崩溃或内存泄漏,同时在数组内置到PyArray对象后仍调用delete
。
有人可以告诉我这里发生了什么吗?
答案 0 :(得分:3)
来自PyArray_SimpleNewFromData
docs:
围绕给定指针指向的数据创建数组包装器。
如果使用PyArray_SimpleNewFromData
创建一个数组,它将围绕您提供的数据创建一个包装器,而不是制作副本。这意味着它包装的数据必须比数组更长。 delete
- 数据违反了该数据。
您有几种选择:
delete
数据之前结束。delete
数据,并使用{将数组base
设置为该对象{3}},因此数组使所有者对象保持活动状态,直到数组本身死亡。