从numpy数组中提取所有垂直切片

时间:2018-07-13 20:20:34

标签: python numpy

我想使用ndeumerate或类似方法从3D numpy数组中提取完整切片。

arr = np.random.rand(4, 3, 3)

我想提取所有可能的arr [:, x,y],其中x,y的范围是0到2

2 个答案:

答案 0 :(得分:2)

ndindex是一种生成与形状对应的索引的便捷方法:

In [33]: arr = np.arange(36).reshape(4,3,3)
In [34]: for xy in np.ndindex((3,3)):
    ...:     print(xy, arr[:,xy[0],xy[1]])
    ...:     
(0, 0) [ 0  9 18 27]
(0, 1) [ 1 10 19 28]
(0, 2) [ 2 11 20 29]
(1, 0) [ 3 12 21 30]
(1, 1) [ 4 13 22 31]
(1, 2) [ 5 14 23 32]
(2, 0) [ 6 15 24 33]
(2, 1) [ 7 16 25 34]
(2, 2) [ 8 17 26 35]

它使用nditer,但是与嵌套的for循环对相比没有任何速度优势。

In [35]: for x in range(3):
    ...:     for y in range(3):
    ...:         print((x,y), arr[:,x,y])

ndenumerate使用arr.flat作为迭代器,但将其用于

In [38]: for xy, _ in np.ndenumerate(arr[0,:,:]):
    ...:     print(xy, arr[:,xy[0],xy[1]])

做同样的事情,迭代3x3子数组的元素。与ndindex一样,它生成索引。该元素将不是您想要的大小为4的数组,所以我忽略了它。


一种不同的方法是将后面的轴弄平,转置,然后在(新的)第一个轴上迭代:

In [43]: list(arr.reshape(4,-1).T)
Out[43]: 
[array([ 0,  9, 18, 27]),
 array([ 1, 10, 19, 28]),
 array([ 2, 11, 20, 29]),
 array([ 3, 12, 21, 30]),
 array([ 4, 13, 22, 31]),
 array([ 5, 14, 23, 32]),
 array([ 6, 15, 24, 33]),
 array([ 7, 16, 25, 34]),
 array([ 8, 17, 26, 35])]

或与以前一样打印:

In [45]: for a in arr.reshape(4,-1).T:print(a)

答案 1 :(得分:0)

为什么不只是

def recursiveSort(sensor_list,n,t): #Recursive Sort
if (n == 0):

    return sensor_list
else:
    for i in range(n-1):
        if sensor_list[i][t] > sensor_list[i + 1][t]:
            temp = sensor_list[i]
            sensor_list[i] =sensor_list[i + 1]
            sensor_list[i + 1] = temp


    return recursiveSort(sensor_list,n - 1,t)

      Dict = {'4213' : ('STEM Center', 0),
    '4201' : ('Foundations Lab', 1),
   '4204' : ('CS Lab', 2),
    '4218' : ('Workshop Room', 3),
   '4205' : ('Tiled Room', 4),
       'out' :  ('Outside', 5),
   }

    sensor_list=[]

       [ sensor_list.append((key,Dict[key][0],Dict[key][1])) for key in Dict ] #Adding 
      values to a dictionary into a tuple using list comprehension

     print(sensor_list)


    print recursiveSort(sensor_list,len(sensor_list),0)

   print recursiveSort(sensor_list,len(sensor_list),1)

   print sensor_list