我有一个由外部模块返回的numpy ndarray。阵列的形状是(3,3,128)。基本上是一堆128个瓷砖,每个瓷砖为3x3。
如何重新排序以使形状变为(128,3,3)。这样,通过图块编号可以更容易地进行索引。最后一步是展平到(128,9),这样就可以很容易地将128个图块中的每一个作为9值向量访问。
答案 0 :(得分:2)
您可以使用指定了新数组顺序的转置,例如
a = np.arange(0,3*3*128).reshape(3,3,128)
a_reorder = a.transpose([2,0,1])
你可以通过比较所有的瓷砖来检查它是否正确,
np.all([np.all(a[:,:,i]==a_reorder[i,:,:]) for i in range(128)])
并用
展平a_flat = a_reorder.reshape(128,9)
答案 1 :(得分:0)
将3 * 3 * 128调整为128 * 3 * 3:
y = einops.rearrange(x, 'x y tile -> tile x y')
或者我们可以在一次操作中直接将其重塑为128 * 9
y = einops.rearrange(x, 'x y tile -> tile (x y)')