使用用户定义的输入进行numpy切片

时间:2017-01-03 15:17:38

标签: arrays python-3.x numpy

我有(在一个更大的项目中)numpy.array中包含的数据。 根据用户输入,我需要将选定的轴(dimAxisNr)移动到数组的第一个维度,并根据用户输入切片一个或多个(包括第一个)维度(例如示例中的Select2和Select0)。

使用此输入,我生成一个DataSelect,其中包含切片所需的信息。但切片阵列的输出大小与使用内联索引的输出大小不同。所以基本上我需要一种方法来生成' 40:2'和' 0:2'来自输入列表。

import numpy as np
dimAxisNr = 1
Select2 = [37,39]
Select0 = [0,1]


plotData = np.random.random((102,72,145,2))

DataSetSize = np.shape(plotData)
DataSelect = [slice(0,item) for item in DataSetSize]
DataSelect[2] = np.array(Select2)
DataSelect[0] = np.array(Select0)


def shift(seq, n):
    n = n % len(seq)
    return seq[n:] + seq[:n]

#Sort and Slice the data

print(np.shape(plotData))
print(DataSelect)

plotData = np.transpose(plotData, np.roll(range(plotData.ndim),-dimAxisNr))
DataSelect = shift(DataSelect,dimAxisNr)

print(DataSelect)
print(np.shape(plotData))
plotData = plotData[DataSelect]
print(np.shape(plotData))

plotDataDirect = plotData[slice(0, 72, None), 37:40:2, slice(0, 2, None), 0:2]
print(np.shape(plotDataDirect))

2 个答案:

答案 0 :(得分:0)

我不确定我完全理解你的问题......

但如果问题是“如何根据[37,39,40,23]等索引列表生成切片?”

然后我会回答:你没必要,只需按原样使用列表来选择正确的索引,如下所示:

a = np.random.rand(4,5)
print(a)
indices = [2,3,1]
print(a[0:2,indices])

请注意,列表的排序很重要:[2,3,1]会产生与[1,2,3]不同的结果

输出

>>> a
array([[ 0.47814802,  0.42069094,  0.96244966,  0.23886243,  0.86159478],
       [ 0.09248812,  0.85569145,  0.63619014,  0.65814667,  0.45387509],
       [ 0.25933109,  0.84525826,  0.31608609,  0.99326598,  0.40698516],
       [ 0.20685221,  0.1415642 ,  0.21723372,  0.62213483,  0.28025124]])
>>> a[0:2,[2,3,1]]
array([[ 0.96244966,  0.23886243,  0.42069094],
       [ 0.63619014,  0.65814667,  0.85569145]])

答案 1 :(得分:0)

我找到了问题的答案。我需要使用numpy.ix _。

以下是工作代码:

import numpy as np
dimAxisNr = 1
Select2 = [37,39]
Select0 = [0,1]


plotData = np.random.random((102,72,145,2))

DataSetSize = np.shape(plotData)
DataSelect = [np.arange(0,item) for item in DataSetSize]

DataSelect[2] = Select2
DataSelect[0] = Select0
#print(list(37:40:2))

def shift(seq, n):
    n = n % len(seq)
    return seq[n:] + seq[:n]

#Sort and Slice the data

print(np.shape(plotData))
print(DataSelect)

plotData = np.transpose(plotData, np.roll(range(plotData.ndim),-dimAxisNr))
DataSelect = shift(DataSelect,dimAxisNr)


plotDataSlice = plotData[np.ix_(*DataSelect)]
print(np.shape(plotDataSlice))

plotDataDirect = plotData[slice(0, 72, None), 37:40:2, slice(0, 2, None), 0:1]
print(np.shape(plotDataDirect))