如何在Numpy中将数组展平为矩阵?

时间:2018-12-12 16:09:27

标签: python arrays numpy

我正在寻找一种优雅的方法,可以根据指定保留尺寸的单个参数将任意形状的阵列展平为矩阵。为了说明,我想

def my_func(input, dim):
    # code to compute output
    return output

例如,假设input形状为2x3x4的数组,output应该是dim=0形状为12x2的数组;用于dim=1形状为8x3的数组;用于dim=2形状为6x8的数组。如果我只想展平最后一个尺寸,则可以通过

轻松实现

input.reshape(-1, input.shape[-1])

但是我想添加添加dim的功能(非常好,无需经历所有可能的情况+检查是否存在条件等)。首先交换尺寸,以使感兴趣的尺寸变尾,然后再应用上面的操作,便有可能。

有帮助吗?

1 个答案:

答案 0 :(得分:1)

我们可以置换轴并重塑-

# a is input array; axis is input axis/dim
np.moveaxis(a,axis,-1).reshape(-1,a.shape[axis])

从功能上讲,它基本上是将指定的轴向后推,然后重塑形状以保持该轴长以形成第二根轴,并合并其余的轴以形成第一根轴。

样品运行-

In [32]: a = np.random.rand(2,3,4)

In [33]: axis = 0

In [34]: np.moveaxis(a,axis,-1).reshape(-1,a.shape[axis]).shape
Out[34]: (12, 2)

In [35]: axis = 1

In [36]: np.moveaxis(a,axis,-1).reshape(-1,a.shape[axis]).shape
Out[36]: (8, 3)

In [37]: axis = 2

In [38]: np.moveaxis(a,axis,-1).reshape(-1,a.shape[axis]).shape
Out[38]: (6, 4)