Python - 获取"子阵列"的3d数组

时间:2014-10-30 09:02:32

标签: python arrays numpy

我想获得3D阵列的多个子阵列。我可以在2D情况下使用在Stack的帖子中找到的函数拆分数组:

def blockshaped(arr, nrows, ncols):
    h, w = arr.shape
    return (arr.reshape(h//nrows, nrows, -1, ncols)
               .swapaxes(1,2)
               .reshape(-1, nrows, ncols))

所以我想将它扩展到3D数组的情况,形式块作为2D arrray但是在第一维的每个切片中。我尝试使用“for循环”但不起作用......

例如:

import numpy as np

#2D case (which works)

test=np.array([[ 2.,  1.,  1., 1.],
        [ 1.,  1.,  1., 1.],
        [ 3.,  1.,  1., 1.],
        [ 1.,  1.,  1., 1.]])

def blockshaped(arr, nrows, ncols): 

    h, w = arr.shape
    return (arr.reshape(h//nrows, nrows, -1, ncols)
               .swapaxes(1,2)
               .reshape(-1, nrows, ncols))


sub = blockshaped(test, 2,2)

我得到了4个“子阵列”:

array([[[ 2.,  1.],
        [ 1.,  1.]],

       [[ 1.,  1.],
        [ 1.,  1.]],

       [[ 3.,  1.],
        [ 1.,  1.]],

       [[ 1.,  1.],
        [ 1.,  1.]]])

但是对于3D数组作为输入:

test2=np.array([[[ 2.,  1.,  1., 1.],
        [ 1.,  1.,  1., 1.],
        [ 3.,  1.,  1., 1.],
        [ 1.,  1.,  1., 1.]],

       [[ 5.,  1.,  1., 1.],
        [ 1.,  1.,  1., 1.],
        [ 2.,  1.,  1., 1.],
        [ 1.,  1.,  1., 1.]]])       

所以在这里我想要相同的分解,但在2“切片”......

def blockshaped(arr, nrows, ncols): 

    h, w, t = arr.shape 
    return (arr.reshape(h//nrows, nrows, -1, ncols)
               .swapaxes(1,2)
               .reshape(-1, nrows, ncols))

我尝试使用“for循环”但不起作用:

for i in range(test2.shape[0]):                     
    sub = blockshaped(test[i,:,:], 2, 2)

1 个答案:

答案 0 :(得分:1)

你的循环解决方案可以用来做类似的事情:

sub = np.array([blockshaped(a, 2, 2) for a in test2])

但您可以稍微修改blockshaped(),在切片前后重新整形数据:

def blockshaped(arr, nrows, ncols):
    need_reshape = False
    if arr.ndim > 2:
        need_reshape = True
    if need_reshape:
        orig_shape = arr.shape
        arr = arr.reshape(-1, arr.shape[-1])
    h, w = arr.shape
    out = (arr.reshape(h//nrows, nrows, -1, ncols)
               .swapaxes(1, 2)
               .reshape(-1, nrows, ncols))
    if need_reshape:
        new_shape = list(out.shape)
        new_shape[0] //= orig_shape[0]
        out = out.reshape([-1,] + new_shape)
    return out