所以,假设我有一个2x2x2x2x2
numpy数组G
。我想创建一个切片函数,具体取决于参数a
和b
(其中a
和b
是索引)。
例如,我希望函数在G[0,:,0,:,:]
和a=0
时返回b=2
。这可能吗?
答案 0 :(得分:5)
您可以创建切片列表:
idx = [0 if i in axes else slice(None) for i in range(G.ndim)]
然后返回G[idx]
:
import numpy as np
np.random.seed(2015)
def getslice(G, axes):
idx = [0 if i in axes else slice(None) for i in range(G.ndim)]
return G[idx]
G = np.random.randint(10, size=(2,2,2,2,2,))
assert np.allclose(getslice(G, [0,2]), G[0,:,0,:,:])
答案 1 :(得分:1)
我认为@unutbu's slice based method
是内存使用率短缺的首选方法。或者,我想提出一种基于transposing
和reshaping
的方法,就像这样 -
# Get axes IDs for remaining axes
o_axes = np.setdiff1d(np.arange(G.ndim),axes)
# Transpose multi-dimensional array such that input axes are brough at th front
sliced_arr = G.transpose(np.concatenate((axes,o_axes)))
# Finally reshape to merge axes into one axis & slice to get first index from it
out = sliced_arr.reshape(np.append(-1,np.array(G.shape)[o_axes]))[0]
验证输出 -
In [23]: G = np.random.randint(0,9,(5,2,4,3,6,4,2))
...: axes = [0,2,5]
...: out_direct = G[0,:,0,:,:,0,:]
...:
In [24]: o_axes = np.setdiff1d(np.arange(G.ndim),axes)
...: sliced_arr = G.transpose(np.concatenate((axes,o_axes)))
...: out = sliced_arr.reshape(np.append(-1,np.array(G.shape)[o_axes]))[0]
...:
In [25]: np.allclose(out_direct,out)
Out[25]: True