I'm trying to implement a NaN-safe shuffling routine in Cython that can shuffle along several axes of a multidimensional array of arbitrary dimension.
In the simple case of a 1D array, one can just shuffle the indices of all non-NaN values with the Fisher-Yates algorithm:
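```cython
cimport numpy as np
import numpy as np

def shuffle1D(np.ndarray[double, ndim=1] x):
    # indices of the non-NaN entries; only these take part in the shuffle
    cdef np.ndarray[np.intp_t, ndim=1] idx = np.where(~np.isnan(x))[0]
    cdef Py_ssize_t i, j, n, m
    randint = np.random.randint
    # Fisher-Yates over the non-NaN positions (randint(i+1) is in [0, i])
    for i in range(idx.shape[0] - 1, 0, -1):
        j = randint(i + 1)
        n, m = idx[i], idx[j]
        x[n], x[m] = x[m], x[n]
```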
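For experimenting outside Cython, the same NaN-safe Fisher-Yates can be sketched in plain Python on a list. This assumes only the standard `math` and `random` modules; the name `shuffle_non_nan` and the `rng` parameter are hypothetical. Note that `random.randint(0, i)` is inclusive on both ends, so it draws from the same range `[0, i]` as NumPy's `randint(i + 1)`.

```python
import math
import random


def shuffle_non_nan(values, rng=random):
    """Shuffle the non-NaN entries of a list in place; NaNs keep their positions.

    Hypothetical plain-Python sketch of the 1D Cython routine.
    """
    # positions that take part in the shuffle
    idx = [k for k, v in enumerate(values) if not math.isnan(v)]
    # Fisher-Yates over those positions: for i = len-1 .. 1 pick j in [0, i]
    for i in range(len(idx) - 1, 0, -1):
        j = rng.randint(0, i)  # inclusive upper bound, unlike numpy's randint
        n, m = idx[i], idx[j]
        values[n], values[m] = values[m], values[n]
    return values


data = [1.0, float("nan"), 2.0, 3.0, float("nan"), 4.0]
shuffle_non_nan(data, random.Random(0))
```

Because only the collected non-NaN positions are permuted, every ordering of the finite values is equally likely and the NaNs never move.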
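I would like to extend this algorithm to handle large multidimensional arrays without reshaping them (reshaping triggers a copy in more complex cases not considered here). To do so, I need to get rid of the fixed input dimension, which seems impossible with both numpy arrays and memoryviews in Cython. Is there a workaround?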
Many thanks in advance!
Answer 0 (score: 4)
Thanks to @Veedrac's comment, this answer now uses more Cython features:
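- A pointer array stores the memory addresses of the values along `axis`.
- Your algorithm is reused, with a check for `nan` values that keeps them from being shuffled.
- It does not create a copy for C-ordered arrays. For Fortran-ordered arrays, the `ravel()` command returns a copy. This could be improved by creating another array of double pointers to carry the values of `x`, probably with some cache penalty...

This code is at least an order of magnitude faster than the other answer, which is based on slices.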
```cython
from libc.stdlib cimport malloc, free
cimport numpy as np
import numpy as np
from numpy.random import randint

cdef extern from "numpy/npy_math.h":
    bint npy_isnan(double x)

def shuffleND(x, int axis=-1):
    cdef np.ndarray[double, ndim=1] v  # flat view of x
    cdef Py_ssize_t i, j, k, num_axis, pos, stride
    cdef double tmp
    cdef double **v_axis
    if axis < 0:
        axis += x.ndim
    if not x.flags.c_contiguous:
        # ravel() would return a copy and the shuffle would never reach x
        raise ValueError("x must be C-contiguous")
    shape = list(x.shape)
    num_axis = shape.pop(axis)
    # one pointer slot per element along `axis`, reused for every lane
    v_axis = <double **>malloc(num_axis * sizeof(double *))
    if v_axis == NULL and num_axis > 0:
        raise MemoryError()
    try:
        # byte strides -> element strides; split off the stride of `axis`
        tmp_strides = [s // x.itemsize for s in x.strides]
        stride = tmp_strides.pop(axis)
        v = x.ravel()
        for indices in np.ndindex(*shape):
            # flat position of the first element of this lane
            pos = sum([s * t for s, t in zip(tmp_strides, indices)])
            # collect pointers to the non-NaN elements only
            k = 0
            for i in range(num_axis):
                if not npy_isnan(v[pos + i*stride]):
                    v_axis[k] = &v[pos + i*stride]
                    k += 1
            # Fisher-Yates over the collected pointers
            for i in range(k - 1, 0, -1):
                j = randint(i + 1)
                tmp = v_axis[i][0]
                v_axis[i][0] = v_axis[j][0]
                v_axis[j][0] = tmp
    finally:
        free(v_axis)
    return x
```
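The pointer arithmetic in `shuffleND` relies on two assumptions: for a C-contiguous `x`, `x.ravel()` is a view (not a copy), and the element at multi-index `idx` sits at flat position `sum(idx[k] * x.strides[k] // x.itemsize)`. Both can be checked from plain NumPy. The sketch below mirrors that offset computation; the helper name `lane_offsets` is hypothetical.

```python
import numpy as np


def lane_offsets(x, axis):
    """Yield (index, start, stride) for every 1D lane of `x` along `axis`.

    Mirrors the offset arithmetic in shuffleND: byte strides become element
    strides, the stride of `axis` is split off, and a lane starts at the dot
    product of the remaining strides with the lane's multi-index.
    """
    strides = [s // x.itemsize for s in x.strides]
    stride = strides.pop(axis)
    shape = list(x.shape)
    shape.pop(axis)
    for index in np.ndindex(*shape):
        start = sum(s * t for s, t in zip(strides, index))
        yield index, start, stride


x = np.arange(24, dtype=float).reshape(2, 3, 4)    # C-contiguous
v = x.ravel()
print(np.shares_memory(x, v))          # True: ravel() returned a view

f = np.asfortranarray(x)               # same values, Fortran order
print(np.shares_memory(f, f.ravel()))  # False: ravel() had to copy

# lane along axis 2 at index (1, 2): element i lives at v[start + i*stride]
index, start, stride = list(lane_offsets(x, axis=2))[-1]
print(index, start, stride)            # (1, 2) 20 1
```

The second check is why the Fortran-order case needs special care: writes through the copy returned by `ravel()` never reach the original array.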
Answer 1 (score: 2)
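The following algorithm is based on slices, makes no copy, and should work for any `np.ndarray`. The main steps are:
- `np.ndindex()` is used to run through the different multidimensional indices, excluding the index belonging to the axis you want to shuffle.

Code: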
```cython
cimport numpy as np
import numpy as np
from numpy.random import randint

def shuffleND(np.ndarray x, axis=-1):
    cdef np.ndarray[np.intp_t, ndim=1] idx
    cdef Py_ssize_t i, j, n, m
    if axis < 0:
        axis += x.ndim
    all_shape = list(np.shape(x))
    shape = all_shape[:]
    shape.pop(axis)
    # iterate over every multi-index of the remaining axes
    for slices in np.ndindex(*shape):
        slices = list(slices)
        # select the full 1D lane along `axis` (a view, not a copy)
        axis_slice = slices[:]
        axis_slice.insert(axis, slice(None))
        idx = np.where(~np.isnan(x[tuple(axis_slice)]))[0]
        # Fisher-Yates over the non-NaN positions of this lane
        for i in range(idx.shape[0] - 1, 0, -1):
            j = randint(i + 1)
            n, m = idx[i], idx[j]
            slice1 = slices[:]
            slice1.insert(axis, n)
            slice2 = slices[:]
            slice2.insert(axis, m)
            slice1 = tuple(slice1)
            slice2 = tuple(slice2)
            x[slice1], x[slice2] = x[slice2], x[slice1]
    return x
```
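The same slice-based idea can also be written with `np.moveaxis`, which returns a view with `axis` moved last. Each `moved[index]` for an `index` produced by `np.ndindex` over the leading shape is then a 1D view into `x`, so swaps write straight into the original array. This is a plain NumPy sketch under those assumptions; the name `shuffle_nd_moveaxis` and the `rng` parameter are hypothetical, and it draws with `numpy.random.Generator.integers` rather than the legacy `randint`. In Cython, a kernel taking a typed 1D memoryview such as `double[:]` should accept these strided lanes directly, since `double[:]` (unlike `double[::1]`) does not require contiguity, so the kernel can keep a fixed `ndim=1` while the input has any number of dimensions.

```python
import numpy as np


def shuffle_nd_moveaxis(x, axis=-1, rng=None):
    """NaN-safe, in-place shuffle of every 1D lane of `x` along `axis`.

    Hypothetical sketch: np.moveaxis returns a view, so each `lane` below is
    a 1D view into `x` and the swaps write into `x` itself.
    """
    rng = np.random.default_rng() if rng is None else rng
    moved = np.moveaxis(x, axis, -1)             # view with `axis` last
    for index in np.ndindex(*moved.shape[:-1]):  # all other axes
        lane = moved[index]                      # 1D view, no copy
        idx = np.flatnonzero(~np.isnan(lane))    # non-NaN positions
        # Fisher-Yates over the non-NaN positions of this lane
        for i in range(len(idx) - 1, 0, -1):
            j = rng.integers(i + 1)              # uniform in [0, i]
            n, m = idx[i], idx[j]
            lane[n], lane[m] = lane[m], lane[n]
    return x


a = np.arange(12, dtype=float).reshape(3, 4)
a[1, 2] = np.nan
shuffle_nd_moveaxis(a, axis=1, rng=np.random.default_rng(0))
```

Because `moved` is a view for both C- and Fortran-ordered inputs, this variant does not depend on memory order, at the cost of Python-level work per swap.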