如何子类化CuPy数组?

时间:2019-02-25 01:49:19

标签: python cupy

对于NumPy,一个人可以使用

class A(np.ndarray): pass
a = np.random.rand(10, 10).view(A)
print(a) # OK

但是对于Cupy,以下操作会导致段错误:

class A(cp.ndarray): pass
a = cp.random.rand(10, 10).view(A)
print(a) # segfault

我仔细阅读了文档,但几乎找不到信息。这是预期的吗? 我正在使用CuPy 5.2,Python 3.6.8和CUDA10。

1 个答案:

答案 0 :(得分:1)

根据numpy.ndarray.view的文档,它采用两个输入参数:dtypetype。如果dtype参数是numpy.ndarray的子类,则将其解释为type参数。

另一方面,根据cupy.ndarray.view的文档,它仅接受一个输入参数:dtype。我猜cupy.ndarray.view现在不支持type参数。