How to reshape a numpy array of shape (#dim1, #dim2, #channel) into (#channel, #dim1, #dim2)

Date: 2017-04-26 00:53:34

Tags: python-2.7 numpy reshape

I have an array of shape (#dim1, #dim2, #channel). I want to reshape it into (#channel, #dim1, #dim2).

plt.reshape(x, (#channel, #dim1, #dim2)) shows me the wrong picture.
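A likely cause: reshape never moves data between axes. It reads the elements in their existing (row-major) order and regroups them into the new shape, so values from different channels end up mixed into each plane. Reordering axes is the job of np.transpose (or np.moveaxis). A minimal sketch on a tiny synthetic array, where the sizes 2, 3, 2 are arbitrary stand-ins for #dim1, #dim2, #channel:

```python
import numpy as np

# Toy "image": 2 rows x 3 columns x 2 channels. Channel 0 holds 0..5 and
# channel 1 holds 100..105, so it is easy to see which channel a value came from.
dim1, dim2, channel = 2, 3, 2
x = np.stack([np.arange(6).reshape(dim1, dim2),
              100 + np.arange(6).reshape(dim1, dim2)], axis=-1)
assert x.shape == (dim1, dim2, channel)

# reshape keeps the flat element order and only regroups it,
# so each resulting "plane" interleaves values from both channels.
wrong = np.reshape(x, (channel, dim1, dim2))

# transpose permutes the axes: new axis 0 is old axis 2 (channel),
# followed by old axes 0 and 1.
right = np.transpose(x, (2, 0, 1))

# moveaxis expresses the same permutation by moving the last axis to the front.
same = np.moveaxis(x, -1, 0)

print(wrong[0].tolist())  # [[0, 100, 1], [101, 2, 102]]  (channels mixed)
print(right[0].tolist())  # [[0, 1, 2], [3, 4, 5]]        (channel 0 only)
```

Both transpose and moveaxis return a view rather than a copy; wrap the result in np.ascontiguousarray if a later step needs contiguous memory.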

1 answer:

Answer 0 (score: 0)

If you are using the CIFAR-10 dataset, you can use the following code:

import numpy as np
import matplotlib.pyplot as plt
import cPickle  # Python 2 only (the question is tagged python-2.7)

def unpickle(path):
    # Each CIFAR-10 batch file is a pickled dict with 'data' and 'labels' keys
    with open(path, 'rb') as fo:
        batch = cPickle.load(fo)
    return batch

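# Read the data: 'data' is a (10000, 3072) uint8 array; each row stores
# 1024 red, then 1024 green, then 1024 blue values (32x32, row-major)
imageDict = unpickle('cifar-10-batches-py/data_batch_2')
imageArray = imageDict['data']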

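# Now we reshape to (10000, 32, 32, 3), channels last as imshow expects.
# order='F' yields (image, column, row, channel), so swapaxes(1, 2)
# puts the axes back into (image, row, column, channel)
imageArray = np.swapaxes(imageArray.reshape(10000, 32, 32, 3, order='F'), 1, 2)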

# Get the labels
labels = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
imageLabels = [labels[i] for i in imageDict['labels']]

# Plot 16 random images in a 4x4 grid
fig, ax = plt.subplots(4, 4, figsize=(8, 8))
for axis in ax.flat:
    index = np.random.randint(0, 10000)  # upper bound is exclusive
    axis.imshow(imageArray[index], origin='upper')
    axis.set_title(imageLabels[index])
    axis.axis('off')
plt.show()  # blocks until the window is closed, unlike fig.show() in a script
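The order='F' reshape followed by np.swapaxes is correct but not obvious. Below is a sketch of an equivalent C-order formulation, checked on random data laid out like a CIFAR-10 row; the batch size n and the random values are placeholders, not the real dataset:

```python
import numpy as np

# Synthetic stand-in for a CIFAR-10 batch: n rows of 3072 values laid out as
# 1024 red, 1024 green, 1024 blue, each a 32x32 image in row-major order.
n = 5
rng = np.random.RandomState(0)
data = rng.randint(0, 256, size=(n, 3072)).astype(np.uint8)

# The answer's approach: Fortran-order reshape gives (n, col, row, channel),
# then swapping axes 1 and 2 yields (n, row, col, channel).
answer_style = np.swapaxes(data.reshape(n, 32, 32, 3, order='F'), 1, 2)

# C-order alternative: split each row into (channel, row, col),
# then move the channel axis to the end for imshow.
c_style = data.reshape(n, 3, 32, 32).transpose(0, 2, 3, 1)

print(answer_style.shape)                     # (5, 32, 32, 3)
print(np.array_equal(answer_style, c_style))  # True
```

If you want the question's channel-first layout (#channel, #dim1, #dim2) for each image, data.reshape(n, 3, 32, 32) alone already provides it; the transpose is only needed because imshow expects channels last.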

Which gives you: [image: 4x4 grid of random CIFAR-10 images, each titled with its class label]
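The unpickle helper in the answer uses cPickle, which exists only in Python 2. A sketch of the same loader for Python 3, assuming the standard pickle module: the batch files were written by Python 2, so encoding='bytes' is passed, and as a consequence the dictionary keys come back as bytes, so the lookups become imageDict[b'data'] and imageDict[b'labels'].

```python
import pickle


def unpickle(path):
    """Load one CIFAR-10 python batch on Python 3.

    encoding='bytes' lets Python 3 read the Python 2 pickle; the keys of the
    returned dict are therefore bytes, e.g. b'data' and b'labels'.
    """
    with open(path, 'rb') as fo:
        return pickle.load(fo, encoding='bytes')
```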