如何根据参数创建切片数组的函数

时间:2015-10-24 02:05:11

标签: python numpy

所以,假设我有一个2x2x2x2x2 numpy数组G。我想创建一个切片函数,具体取决于参数ab(其中ab是索引)。

例如,我希望函数在G[0,:,0,:,:]a=0时返回b=2。这可能吗?

2 个答案:

答案 0 :(得分:5)

您可以创建切片列表:

idx = [0 if i in axes else slice(None) for i in range(G.ndim)]

然后返回G[idx]

import numpy as np
np.random.seed(2015)

def getslice(G, axes):
    idx = [0 if i in axes else slice(None) for i in range(G.ndim)]
    return G[idx]

G = np.random.randint(10, size=(2,2,2,2,2,))
assert np.allclose(getslice(G, [0,2]), G[0,:,0,:,:])

答案 1 :(得分:1)

我认为@unutbu's slice based method是内存使用率短缺的首选方法。或者,我想提出一种基于transposingreshaping的方法,就像这样 -

# Get axes IDs for remaining axes
o_axes = np.setdiff1d(np.arange(G.ndim),axes)

# Transpose multi-dimensional array such that input axes are brough at th front
sliced_arr = G.transpose(np.concatenate((axes,o_axes)))

# Finally reshape to merge axes into one axis & slice to get first index from it
out = sliced_arr.reshape(np.append(-1,np.array(G.shape)[o_axes]))[0]

验证输出 -

In [23]: G = np.random.randint(0,9,(5,2,4,3,6,4,2))
    ...: axes = [0,2,5]
    ...: out_direct = G[0,:,0,:,:,0,:]
    ...: 

In [24]: o_axes = np.setdiff1d(np.arange(G.ndim),axes)
    ...: sliced_arr = G.transpose(np.concatenate((axes,o_axes)))
    ...: out = sliced_arr.reshape(np.append(-1,np.array(G.shape)[o_axes]))[0]
    ...: 

In [25]: np.allclose(out_direct,out)
Out[25]: True