I am writing a Python package that performs various complex statistical analysis tasks along an arbitrary axis of a numpy array of arbitrary shape.
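Currently, to allow the array shape and the axis to be arbitrary, I just permute the array so that the axis of interest sits on the far right-hand side, and then flatten the left-hand axes into one. For example, if the array has shape (3,4,5) and we want to perform some operation along axis 1, it gets converted to shape (15,4), the operation is performed along axis -1, and the result is converted back to shape (3,4,5) and returned by the function.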
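The permute-and-flatten approach described above can be sketched roughly as follows. This is a minimal illustration, not the package's actual code: `apply_rowwise` is a hypothetical helper name, and it assumes `func` returns a 1-D result of the same length as its input.

```python
import numpy as np

def apply_rowwise(data, func, axis=-1):
    """Apply func to each 1-D slice along `axis` (hypothetical helper)."""
    moved = np.moveaxis(data, axis, -1)        # axis of interest goes last
    flat = moved.reshape(-1, moved.shape[-1])  # collapse the leading axes into one
    out = np.empty_like(flat)
    for i in range(flat.shape[0]):
        out[i, :] = func(flat[i, :])           # assumes func preserves the slice length
    # Undo the flattening, then move the last axis back to its original position
    return np.moveaxis(out.reshape(moved.shape), -1, axis)

a = np.arange(60, dtype=float).reshape(3, 4, 5)
result = apply_rowwise(a, np.cumsum, axis=1)   # intermediate shape is (15, 4)
```

Here `np.cumsum` only stands in for the "arbitrary complex" operation; any function that maps a length-n vector to a length-n vector would fit the same pattern.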
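With all of this array manipulation, I suspect this approach may be unnecessarily slow. Is there a clean way to iterate over all but one dimension of an array? In the example above, that would mean visiting [0,:,0], [0,:,1], ..., [2,:,3], [2,:,4], but for arbitrary array shapes and axis positions.
Perhaps np.ndenumerate or np.ndindex could be used for this?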
Edit: Is it possible to do this with np.nditer? Perhaps that could match the speed of permuting and reshaping.
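For reference, one way the requested iteration could look with `np.ndindex` is sketched below. The generator name `iter_slices` is hypothetical, and this is only an illustration of the idea, with no claim about its speed.

```python
import numpy as np

def iter_slices(data, axis=-1):
    """Yield every 1-D view of `data` along `axis` (hypothetical helper)."""
    axis = axis % data.ndim                           # normalize negative axes
    other = data.shape[:axis] + data.shape[axis + 1:]  # shape without `axis`
    for idx in np.ndindex(*other):
        # Re-insert a full slice at the position of `axis`
        yield data[idx[:axis] + (slice(None),) + idx[axis:]]

a = np.arange(60).reshape(3, 4, 5)
slices = list(iter_slices(a, axis=1))  # a[0,:,0], a[0,:,1], ..., a[2,:,4]
```

Each yielded item is a view produced by basic indexing, so assigning into it (for example `s[...] = 0`) writes back into the original array.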
Answer 0 (score: 0)
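It turns out that simply transposing and reshaping really is faster. So I guess the answer is: don't do that. It is better to permute and reshape, as I was already doing.
Here is the code from my project.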
# Benchmark
import numpy as np

f = lambda x: x # can change this to any arbitrary function

def test1(data, axis=-1):
    # Test the lead flatten approach
    data, shape = lead_flatten(permute(data, axis))
    output = np.empty(data.shape)
    for i in range(data.shape[0]): # iterate along first dimension; each row is an autocor
        output[i,:] = f(data[i,:]) # arbitrary complex equation
    return unpermute(lead_unflatten(output, shape), axis)

def test2(data, axis=-1):
    # Test the new approach
    output = np.empty(data.shape)
    for d,o in zip(iter_1d(data, axis), iter_1d(output, axis)):
        o[...] = f(d) # write through the view into output
    return output
# Iterator class
class iter_1d(object):
    def __init__(self, data, axis=-1):
        axis = (axis % data.ndim) # e.g. for 3D array, -1 becomes 2
        self.data = data
        self.axis = axis
    def __iter__(self):
        # Shape of all dimensions except the axis of interest
        shape = (s for i,s in enumerate(self.data.shape) if i!=self.axis)
        self.iter = np.ndindex(*shape)
        return self
    def __next__(self):
        idx = next(self.iter) # built-in next() instead of the Python 2 style .next()
        idx = [*idx]
        idx.insert(self.axis, slice(None))
        return self.data[tuple(idx)] # index with a tuple; a list here is not treated as a multi-axis index
# Permute and reshape functions
def lead_flatten(data, nflat=None):
    shape = list(data.shape)
    if nflat is None:
        nflat = data.ndim-1 # all but last dimension
    if nflat<=0: # just apply singleton dimension
        return data[None,...], shape
    return np.reshape(data, (np.prod(data.shape[:nflat]).astype(int), *data.shape[nflat:]), order='C'), shape # C order is row-major

def lead_unflatten(data, shape, nflat=None):
    if nflat is None:
        nflat = len(shape) - 1 # all but last dimension
    if nflat<=0: # we artificially added a singleton dimension; remove it
        return data[0,...]
    if data.shape[0] != np.prod(shape[:nflat]):
        raise ValueError(f'Number of leading elements {data.shape[0]} does not match leading shape {shape[:nflat]}.')
    if not all(s1==s2 for s1,s2 in zip(data.shape[1:], shape[nflat:])):
        raise ValueError(f'Trailing dimensions on data, {data.shape[1:]}, do not match trailing dimensions on new shape, {shape[nflat:]}.')
    return np.reshape(data, shape, order='C')

def permute(data, source=-1, destination=-1):
    # Move the axis of interest to the destination position (last by default)
    data = np.moveaxis(data, source, destination)
    return data

def unpermute(data, source=-1, destination=-1):
    # Inverse of permute: move the axis back to where it came from
    data = np.moveaxis(data, destination, source)
    return data
Here are the results of some %timeit runs.
import numpy as np
a = np.random.rand(10,20,30,40)
%timeit -r10 -n10 test1(a, axis=2) # around 12ms
%timeit -r10 -n10 test2(a, axis=2) # around 22ms
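Since %timeit is an IPython magic, a similar measurement outside IPython could be made with the standard-library timeit module. The sketch below is only an illustration: `flatten_roundtrip` is a hypothetical stand-in for the functions being timed, and the timings it prints depend on the machine.

```python
import timeit
import numpy as np

a = np.random.rand(10, 20, 30, 40)

def flatten_roundtrip(data, axis):
    """Hypothetical stand-in workload: permute, flatten, then undo both."""
    moved = np.moveaxis(data, axis, -1)
    flat = moved.reshape(-1, moved.shape[-1])
    return np.moveaxis(flat.reshape(moved.shape), -1, axis)

# Mirrors `%timeit -r10 -n10`: 10 repeats of 10 calls each
times = timeit.repeat(lambda: flatten_roundtrip(a, 2), repeat=10, number=10)
best_ms = min(times) / 10 * 1e3  # best per-call time in milliseconds
print(f"best of 10 repeats: {best_ms:.3f} ms per call")
```

Swapping the lambda body for a call to test1 or test2 would time those functions in the same way.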