将3D numpy数组拆分为较小的3D数组

时间:2019-09-27 12:36:22

标签: python arrays numpy

我有一个3D np.array

arr = np.array([ 
                [ [0, 205, 25], [210, 150, 30], [0, 0, 0], [1, 2, 3], [4, 5, 6], [7, 8, 9] ],
                [ [0, 255, 0], [255, 40, 0], [0, 0, 200], [7, 8, 9], [10, 11, 12], [120, 51, 58] ],
                [ [0, 0, 30], [0, 40, 0], [200, 100, 20], [12, 13, 14], [15, 16, 17], [13, 78, 84], ],
                [ [0, 205, 25], [210, 150, 30], [0, 0, 0], [1, 2, 3], [4, 5, 6], [7, 8, 9] ],
                [ [0, 255, 0], [255, 40, 0], [0, 0, 200], [7, 8, 9], [10, 11, 12], [120, 51, 58] ],
                [ [0, 0, 30], [0, 40, 0], [200, 100, 20], [12, 13, 14], [15, 16, 17], [13, 78, 84], ],
                [ [0, 205, 25], [210, 150, 30], [0, 0, 0], [1, 2, 3], [4, 5, 6], [7, 8, 9] ],
                [ [0, 255, 0], [255, 40, 0], [0, 0, 200], [7, 8, 9], [10, 11, 12], [120, 51, 58] ],
                [ [0, 0, 30], [0, 40, 0], [200, 100, 20], [12, 13, 14], [15, 16, 17], [13, 78, 84], ],
              ])

我需要将其拆分为3x2x3 3D阵列

[ [0, 205, 25], [210, 150, 30],    [0, 0, 0], [1, 2, 3],             [4, 5, 6], [7, 8, 9] ],
[ [0, 255, 0],  [255, 40, 0],      [0, 0, 200], [7, 8, 9],           [10, 11, 12], [120, 51, 58] ],
[ [0, 0, 30],   [0, 40, 0],        [200, 100, 20], [12, 13, 14],     [15, 16, 17], [13, 78, 84], ],

[ [0, 205, 25], [210, 150, 30],    [0, 0, 0], [1, 2, 3],             [4, 5, 6], [7, 8, 9] ],
[ [0, 255, 0],  [255, 40, 0],      [0, 0, 200], [7, 8, 9],           [10, 11, 12], [120, 51, 58] ],
[ [0, 0, 30],   [0, 40, 0],        [200, 100, 20], [12, 13, 14],     [15, 16, 17], [13, 78, 84], ],

[ [0, 205, 25], [210, 150, 30],    [0, 0, 0], [1, 2, 3],             [4, 5, 6], [7, 8, 9] ],
[ [0, 255, 0],  [255, 40, 0],      [0, 0, 200], [7, 8, 9],           [10, 11, 12], [120, 51, 58] ],
[ [0, 0, 30],   [0, 40, 0],        [200, 100, 20], [12, 13, 14],     [15, 16, 17], [13, 78, 84], ],

使用我用空格选择的3D块获得4D阵列。零元素必须为

[ 
    [[0, 205, 25], [210, 150, 30]],
    [[0, 255, 0], [255, 40, 0]],
    [[0, 0, 30], [0, 40, 0]] 
]

以此类推。

我已经阅读了this问题,但仍然不了解如何执行此操作(为什么我们需要重塑,转置和重塑以及transpose()中的神奇数字)。我可以尝试编写自己的函数,但我想知道如何以本机的方式进行操作。

1 个答案:

答案 0 :(得分:1)

您可以重塑和移置它

from numba import njit
@njit
def find_bounding_intervals(A, v):
    rows_L = []
    rows_R = []

    i = 0
    for row in range(A.shape[0]):
        while v[i] < A[row,0] and v[i] < A[row,1]:
            i += 1
        if A[row,0] <= v[i] <= A[row,1]:
            rows_L.append(A[row,0])
            rows_R.append(A[row,1])
    return np.array([rows_L, rows_R]).T