PyArray_SimpleNewFromData中的奇怪分段错误

时间:2017-01-09 16:40:14

标签: python c++ numpy cython

我的问题在“精神上”与Segmentation fault in PyArray_SimpleNewFromData

类似

我有一个如下所示的C代码:(原始代码实际测试malloc()是否返回NULL)

  1 #include <Python.h>
  2 #include <numpy/arrayobject.h>  // (Not sure if right import)
  3 #include <stdlib.h>
  4 #include <stdio.h>
  5 
  6 double *calculate_dW(npy_intp *dim_w) {
  7         int i;
  8         double* data = (double*)malloc(sizeof(double) * dim_w[0]);
  9         
 10         /* Inserts some dummy data */
 11         for (i = 0; i < dim_w[0]; i++)
 12                 data[i] = i;
 13         
 14         return data;
 15 }

然后是一个将它包装在函数中的Cython代码:

  1 import cython
  2 import numpy as np
  3 cimport numpy as np
  4 
  5 cdef extern double *calculate_dW(np.npy_intp *dim_w)
  6 
  7 def run_calculate_dW(np.ndarray[np.npy_intp, ndim=1, mode="c"] dim_w):
  8         print("Will call calculate_dW")
  9         cdef double *dW = calculate_dW(&dim_w[0])
 10 
 11         print("Will call PyArray_SimpleNewFromData")
 12         ret = np.PyArray_SimpleNewFromData(
 13                 1,
 14                 &dim_w[0],
 15                 np.NPY_FLOAT64,
 16                 dW)
 17         print("Will print")
 18         print(ret)
 19         print("Will return")
 20         return ret

我用

进行测试
  # runTest.py
  1 import numpy as np
  2 import multiply
  3 a = np.array((10,)) # as expected, using `np.array(10)` won't work
  4 print a
  5 multiply.run_calculate_dW(a)

获得以下输出

$ PYTHONPATH=build/lib.linux-x86_64-2.7/ python runTest.py 
[10]
Will call calculate_dW
Will call PyArray_SimpleNewFromData
Segmentation fault (core dumped)

(即调用PyArray_SimpleNewFromData()中的SegFault(如果我用ret = 1替换它,分段错误就会消失)。调试时,我尝试了很多东西:

  • 将尺寸数更改为1;
  • 增加malloc()分配的内存量(以保证我不会访问任何我不应该访问的内容);
  • np.NPY_FLOAT32更改为np.float32;
  • 改变我传递新阵列“形状”的方式。

我相信我完全遵循documentation以及the answer to this other question。我似乎没有得到任何编译器错误或警告。

但是,我确实注意到互联网上的所有其他代码在调用PyArray_SimpleNewFromData时都使用C(而不是Python)。我尝试从C函数返回PyObject*,但无法将其编译。

另外,我确实得到一些“使用已弃用的NumPy API,通过#defining NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION”警告禁用它;但我读过我可以安全地忽略它们。 (Cython Numpy warning about NPY_NO_DEPRECATED_API when using MemoryView

有什么建议吗? (还有,从dW创建一个numpy数组的任何其他方式?)

1 个答案:

答案 0 :(得分:4)

我认为问题在于,当它需要指向整数的指针时,你将Python列表作为第二个参数传递给PyArray_SimpleNewFromData。我编译时有点意外。

尝试:

ret = np.PyArray_SimpleNewFromData(
                     4,
                     &dim_w[0], # pointer to first element
                     np.NPY_FLOAT64,
                     dW)

请注意,我还将类型更改为NPY_FLOAT64,因为它应与double匹配。

我还将dim_w的定义更改为

np.ndarray[np.NPY_INTP, ndim=1, mode="c"] dim_w

确保数组的类型与numpy期望的匹配。这可能还需要将calculate_dW的签名更改为double *calculate_dW(intptr_t *dim_w)以匹配。

编辑:第二个问题是您需要包含

np.import_array()
你的Cython文件中的

(在你的导入后只是在顶层)。这为numpy做了一些设置。原则上我认为文档建议您在执行cimport numpy时始终包含它。在实践中,它有时只是重要的,这是其中之一。

(现在测试答案)