我想用另一个数组覆盖部分PyOpenCL数组。 我们说
import numpy as np, pyopencl.array as cla
a = cla.zeros(queue,(3,3),'int8')
b = cla.ones(queue,(2,2),'int8')
现在我想做a[0:2,0:2] = b
之类的事情并希望得到
1 1 0
1 1 0
0 0 0
如果没有将所有内容复制到主机以获取速度原因,我该如何做到这一点?
答案 0 :(得分:1)
Pyopencl数组能够做到这一点 - 在这个答案时到非常有限的程度 - 使用numpy语法(即你究竟是如何编写它)的限制是:你只能使用沿第一轴切片。
import numpy as np, pyopencl.array as cla
a = cla.zeros(queue,(3,3),'int8')
b = cla.ones(queue,(2,3),'int8')
# note b is 2x3 here
a[0:2]=b #<-works
a[0:2,0:2]=b[:,0:2] #<-Throws an error about non-contiguity
因此,a[0:2,0:2] = b
将不起作用,因为目标切片数组具有非连续数据。
我所知道的唯一解决方案(因为pyopencl.array类中的任何内容都无法使用切片数组/非连续数据),就是编写自己的openCL内核来做副本“手工”。
这是我用来在所有dtype的1D或2D pyopencl数组上复制的一段代码:
import numpy as np, pyopencl as cl, pyopencl.array as cla
ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx)
kernel = cl.Program(ctx, """__kernel void copy(
__global char *dest, const int offsetd, const int stridexd, const int strideyd,
__global const char *src, const int offsets, const int stridexs, const int strideys,
const int word_size) {
int write_idx = offsetd + get_global_id(0) + get_global_id(1) * stridexd + get_global_id(2) * strideyd ;
int read_idx = offsets + get_global_id(0) + get_global_id(1) * stridexs + get_global_id(2) * strideys;
dest[write_idx] = src[read_idx];
}""").build()
def copy(dest,src):
assert dest.dtype == src.dtype
assert dest.shape == src.shape
if len(dest.shape) == 1 :
dest.shape=(dest.shape[0],1)
src.shape=(src.shape[0],1)
dest.strides=(dest.strides[0],0)
src.strides=(src.strides[0],0)
kernel.copy(queue, (src.dtype.itemsize,src.shape[0],src.shape[1]), None, dest.base_data, np.uint32(dest.offset), np.uint32(dest.strides[0]), np.uint32(dest.strides[1]), src.base_data, np.uint32(src.offset), np.uint32(src.strides[0]), np.uint32(src.strides[1]), np.uint32(src.dtype.itemsize))
a = cla.zeros(queue,(3,3),'int8')
b = cla.ones(queue,(2,2),'int8')
copy(a[0:2,0:2],b)
print(a)
答案 1 :(得分:0)
在pyopencl邮件列表中,AndreasKlöckner给了我一个提示:pyopencl.array
中有一个名为multiput()
的未记录函数。语法是这样的:
cla.multi_put([arr],indices,out=[out])
'arr'是源数组,'out'是目标数组,'indices'是1D的int数组(也在设备上),其中包含行数为主的线性元素索引。
例如,在我的第一篇文章中,将'b'放入'a'的索引将是(0,1,3,4)。你只需要以某种方式将索引放在一起,并且可以使用multiput()而不是编写内核。 len(indices)
当然必须等于b.size
。还有一个take()
和multitake()
函数用于从数组中读取元素。