将结构化数组传递给Cython失败了(我认为这是一个Cython错误)

时间:2015-05-04 16:40:25

标签: numpy cython

假设我有

a = np.zeros(2, dtype=[('a', np.int),  ('b', np.float, 2)])
a[0] = (2,[3,4])
a[1] = (6,[7,8])

然后我定义了相同的Cython结构

import numpy as np
cimport numpy as np

cdef packed struct mystruct:
  np.int_t a
  np.float_t b[2]

def test_mystruct(mystruct[:] x):
  cdef:
    int k
    mystruct y

  for k in range(2):
    y = x[k]
    print y.a
    print y.b[0]
    print y.b[1]

在此之后,我运行

test_mystruct(a)

我收到了错误:

ValueError                                Traceback (most recent call last)
<ipython-input-231-df126299aef1> in <module>()
----> 1 test_mystruct(a)
_cython_magic_5119cecbaf7ff37e311b745d2b39dc32.pyx in _cython_magic_5119cecbaf7ff37e311b745d2b39dc32.test_mystruct (/auto/users/pwang/.cache/ipython/cython/_cython_magic_5119cecbaf7ff37e311b745d2b39dc32.c:1364)()
ValueError: Expected 1 dimension(s), got 1

我的问题是如何修复它?谢谢。

2 个答案:

答案 0 :(得分:3)

这个pyx编译并导入ok:

import numpy as np
cimport numpy as np

cdef packed struct mystruct:
  int a[2]    # change from plain int
  float b[2]
  int c

def test_mystruct(mystruct[:] x):
  cdef:
    int k
    mystruct y

  for k in range(2):
    y = x[k]
    print y.a
    print y.b[0]
    print y.b[1]

dt='2i,2f,i'
b=np.zeros((3,),dtype=dt)
test_mystruct(b)

我从我的评论中提到的测试示例开始,并与您的案例一起玩。我认为关键的变化是将打包结构的第一个元素定义为int a[2]。因此,如果任何元素是一个数组,则第一个元素必须是一个正确设置结构的数组。

显然是测试文件没有捕获的错误。

将元素定义为int a[1]并不起作用,可能是因为dtype删除了这样的维度:

In [47]: np.dtype([('a', np.int, 1),  ('b', np.float, 2)])
Out[47]: dtype([('a', '<i4'), ('b', '<f8', (2,))])

在提出并修补问题之前,定义dtype来解决这个问题并不难。

struct可能有a[1],但数组dtype必须使用元组指定大小:('a','i',(1,))('a','i',1)的大小为()

如果其中一个struct数组是2d,看起来它们都必须是:

cdef packed struct mystruct:
  int a[1][1]
  float b[2][1]
  int c[2][2]

https://github.com/cython/cython/blob/c4c2e3d8bd760386b26dbd6cffbd4e30ba0a7d13/tests/memoryview/numpy_memoryview.pyx

退后一步,我想知道在cython中处理复杂结构化数组的重点是什么。对于某些操作而言,将字段作为单独的变量传递也不会起作用。例如myfunc(a['a'],a['b'])而不是myfunc(a)

答案 1 :(得分:1)

有一种获取c结构的dtype的通用方法,但它涉及一个临时变量:

cdef mystruct _tmp
dt = np.asarray(<mystruct[:1]>(&_tmp)).dtype

这需要至少numpy 1.5。请参阅此处的讨论:https://github.com/scikit-learn/scikit-learn/pull/2298