我的问题在“精神上”与Segmentation fault in PyArray_SimpleNewFromData
类似我有一个如下所示的C代码:(原始代码实际测试malloc()
是否返回NULL)
1 #include <Python.h>
2 #include <numpy/arrayobject.h> // (Not sure if right import)
3 #include <stdlib.h>
4 #include <stdio.h>
5
6 double *calculate_dW(npy_intp *dim_w) {
7 int i;
8 double* data = (double*)malloc(sizeof(double) * dim_w[0]);
9
10 /* Inserts some dummy data */
11 for (i = 0; i < dim_w[0]; i++)
12 data[i] = i;
13
14 return data;
15 }
然后是一个将它包装在函数中的Cython代码:
1 import cython
2 import numpy as np
3 cimport numpy as np
4
5 cdef extern double *calculate_dW(np.npy_intp *dim_w)
6
7 def run_calculate_dW(np.ndarray[np.npy_intp, ndim=1, mode="c"] dim_w):
8 print("Will call calculate_dW")
9 cdef double *dW = calculate_dW(&dim_w[0])
10
11 print("Will call PyArray_SimpleNewFromData")
12 ret = np.PyArray_SimpleNewFromData(
13 1,
14 &dim_w[0],
15 np.NPY_FLOAT64,
16 dW)
17 print("Will print")
18 print(ret)
19 print("Will return")
20 return ret
我用
进行测试 # runTest.py
1 import numpy as np
2 import multiply
3 a = np.array((10,)) # as expected, using `np.array(10)` won't work
4 print a
5 multiply.run_calculate_dW(a)
获得以下输出
$ PYTHONPATH=build/lib.linux-x86_64-2.7/ python runTest.py
[10]
Will call calculate_dW
Will call PyArray_SimpleNewFromData
Segmentation fault (core dumped)
(即调用PyArray_SimpleNewFromData()中的SegFault(如果我用ret = 1
替换它,分段错误就会消失)。调试时,我尝试了很多东西:
malloc()
分配的内存量(以保证我不会访问任何我不应该访问的内容); np.NPY_FLOAT32
更改为np.float32
; 我相信我完全遵循documentation以及the answer to this other question。我似乎没有得到任何编译器错误或警告。
但是,我确实注意到互联网上的所有其他代码在调用PyArray_SimpleNewFromData时都使用C(而不是Python)。我尝试从C函数返回PyObject*
,但无法将其编译。
另外,我确实得到一些“使用已弃用的NumPy API,通过#defining NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION”警告禁用它;但我读过我可以安全地忽略它们。 (Cython Numpy warning about NPY_NO_DEPRECATED_API when using MemoryView)
有什么建议吗? (还有,从dW
创建一个numpy数组的任何其他方式?)
答案 0 :(得分:4)
我认为问题在于,当它需要指向整数的指针时,你将Python列表作为第二个参数传递给PyArray_SimpleNewFromData
。我编译时有点意外。
尝试:
ret = np.PyArray_SimpleNewFromData(
4,
&dim_w[0], # pointer to first element
np.NPY_FLOAT64,
dW)
请注意,我还将类型更改为NPY_FLOAT64
,因为它应与double
匹配。
我还将dim_w
的定义更改为
np.ndarray[np.NPY_INTP, ndim=1, mode="c"] dim_w
确保数组的类型与numpy期望的匹配。这可能还需要将calculate_dW
的签名更改为double *calculate_dW(intptr_t *dim_w)
以匹配。
编辑:第二个问题是您需要包含
行np.import_array()
你的Cython文件中的(在你的导入后只是在顶层)。这为numpy做了一些设置。原则上我认为文档建议您在执行cimport numpy
时始终包含它。在实践中,它有时只是重要的,这是其中之一。
(现在测试答案)