为什么使用matplotlib无法正确显示CIFAR-10图像?

时间:2016-03-14 19:08:06

标签: python image matplotlib machine-learning computer-vision

从训练集中我拍了一张大小为(3,32,32)的图像('img')。 我用过plt.imshow(img.T)。图像不清晰。现在我必须对图像('img')进行更改,以使其更清晰可见。 感谢。

This is the image I got

9 个答案:

答案 0 :(得分:15)

随后打印5X5网格随机Cifar10图像。它并不模糊,但也不完美。欢迎任何建议。

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from six.moves import cPickle 

f = open('data/cifar10/cifar-10-batches-py/data_batch_1', 'rb')
datadict = cPickle.load(f,encoding='latin1')
f.close()
X = datadict["data"] 
Y = datadict['labels']
X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("uint8")
Y = np.array(Y)

#Visualizing CIFAR 10
fig, axes1 = plt.subplots(5,5,figsize=(3,3))
for j in range(5):
    for k in range(5):
        i = np.random.choice(range(len(X)))
        axes1[j][k].set_axis_off()
        axes1[j][k].imshow(X[i:i+1][0])

答案 1 :(得分:3)

由于插值,图像模糊。要防止matplotlib模糊,请使用关键字imshow

调用interpolation='nearest'
plt.imshow(img.T, interpolation='nearest')

此外,当您使用转置时,您的x和y轴似乎正在交换,因此您可能希望显示如下:

plt.imshow(np.transpose(img, (1, 2, 0)), interpolation='nearest')

答案 2 :(得分:3)

我使用以下代码将所有CIFAR数据显示为一个大图像。代码显示图像,但是如果你想保存它而不是脱口而出我使用plt.savefig(fname, format='png', dpi=1000)

import numpy as np
import matplotlib.pyplot as plt

def reshape_and_print(self, cifar_data):
    # number of images in rows and columns
    rows = cols = np.sqrt(cifar_data.shape[0]).astype(np.int32)
    # Image hight and width. Divide by 3 because of 3 color channels
    imh = imw = np.sqrt(cifar_data.shape[1] // 3).astype(np.int32)
    # reshape to number of images X color channels X image size
    # transpose to color channels X number of images X image size
    timg = cifar_data.reshape(rows * cols, 3, imh * imh).transpose(1, 0, 2)
    # reshape to color channels X rows X cols X image hight X image with
    # swap axis to color channels X rows X image hight X cols X image with
    timg = timg.reshape(3, rows, cols, imh, imw).swapaxes(2, 3)
    # reshape to color channels X combined image hight X combined image with
    # transpose to combined image hight X combined image with X color channels
    timg = timg.reshape(3, rows * imh, cols * imw).transpose(1, 2, 0)

    plt.imshow(timg)
    plt.show()

我制作了一个快速数据助手类,用于一个小型测试项目,我希望它有用:

import gzip
import pickle
import numpy as np
import matplotlib.pyplot as plt


class DataSet(object):

    def __init__(self, seed=42, setsize=10000):
        self.seed = seed
        # set the seed for reproducability
        np.random.seed(seed)
        # load the data
        train_set, test_set = self.load_data()
        # self.split_data(train_set, valid_set, test_set)
        self.split_data(train_set, test_set, setsize)

    def split_data(self, data_set, test_set, split_size):
        permutation = np.random.permutation(data_set.shape[0])
        self.train = data_set[permutation[:split_size]]
        self.valid = data_set[permutation[split_size:split_size * 2]]
        self.test = test_set[:split_size]

    def reshape_for_print(self, data):
        raise NotImplemented

    def load_data(self):
        raise NotImplemented

    def show_all_imgs(self, data):
        raise NotImplemented


class CIFAR(DataSet):

    def load_data(self):
        # try to load data
        with open('./data/cifar-100-python/train', 'rb') as f:
            data = pickle.load(f, encoding='latin1')
        train_set = data['data'].astype(np.float32) / 255.0

        with open('./data/cifar-100-python/test', 'rb') as f:
            data = pickle.load(f, encoding='latin1')
        test_set = data['data'].astype(np.float32) / 255.0

        return train_set, test_set

    def reshape_for_print(self, data):
        gh = gw = np.sqrt(data.shape[0]).astype(np.int32)
        imh = imw = np.sqrt(data.shape[1] // 3).astype(np.int32)
        timg = data.reshape(gh * gw, 3, imh * imh).transpose(1, 0, 2)
        timg = timg.reshape(3, gh, gw, imh, imw).swapaxes(2, 3)
        timg = timg.reshape(3, gh * imh, gw * imw).transpose(1, 2, 0)
        return timg

    def show_all_imgs(self, data):
        timg = self.reshape_for_print(data)
        plt.imshow(timg)
        plt.show()


class MNIST(DataSet):

    def load_data(self):
        # try to load data
        with gzip.open('./data/mnist.pkl.gz', 'rb') as f:
            train_set, valid_set, test_set = pickle.load(f, encoding='latin1')
        return train_set[0], test_set[0]

    def reshape_for_print(self, data):
        gh = gw = np.sqrt(data.shape[0]).astype(np.int32)
        imh = imw = np.sqrt(data.shape[1]).astype(np.int32)
        timg = data.reshape(gh, gw, imh, imw).swapaxes(1, 2)
        timg = timg.reshape(gh * imh, gw * imw)
        return timg

    def show_all_imgs(self, data):
        timg = self.reshape_for_print(data)
        plt.imshow(timg, cmap=plt.cm.gray)
        plt.show()

答案 3 :(得分:3)

要显示图像时,请确保不对数据集进行标准化。

示例:

装载机...

import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt


train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('../data', train=True, download=True,
                     transform=transforms.Compose([
                         transforms.RandomHorizontalFlip(),
                         transforms.ToTensor(),
                        #  transforms.Normalize(
                        #      (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
                     ])),
    batch_size=64, shuffle=True)

显示图像的代码...

img = next(iter(train_loader))[0][0]
plt.imshow(transforms.ToPILImage()(img))

归一化

Normalized

没有标准化

Not normalized

答案 4 :(得分:2)

尝试使用

import matplotlib.pyplot as plt
from scipy.misc import toimage
plt.imshow(toimage(img))

我不是100%确定代码是如何工作的,但我认为因为图像存储在浮点numpy数组中,所以imshow()函数很难将它们映射到正确的颜色。通过使用toimage()将它们转换为图像,您可以将它们转换为imshow()所期望的正确图像格式,即不是数组,而是编码为.png或.jpg的图像。

每次我想在python中显示图像时,此代码都适用于我。

答案 5 :(得分:2)

此文件读取cifar10 dataset并使用matplotlib绘制单个图像。

import _pickle as pickle
import argparse
import numpy as np
import os
import matplotlib.pyplot as plt

cifar10 = "./cifar-10-batches-py/"

parser = argparse.ArgumentParser("Plot training images in cifar10 dataset")
parser.add_argument("-i", "--image", type=int, default=0, 
                    help="Index of the image in cifar10. In range [0, 49999]")
args = parser.parse_args()


def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

def cifar10_plot(data, meta, im_idx=0):
    im = data[b'data'][im_idx, :]

    im_r = im[0:1024].reshape(32, 32)
    im_g = im[1024:2048].reshape(32, 32)
    im_b = im[2048:].reshape(32, 32)

    img = np.dstack((im_r, im_g, im_b))

    print("shape: ", img.shape)
    print("label: ", data[b'labels'][im_idx])
    print("category:", meta[b'label_names'][data[b'labels'][im_idx]])         

    plt.imshow(img) 
    plt.show()


def main():
    batch = (args.image // 10000) + 1
    idx = args.image - (batch-1)*10000

    data = unpickle(os.path.join(cifar10, "data_batch_" + str(batch)))
    meta = unpickle(os.path.join(cifar10, "batches.meta"))

    cifar10_plot(data, meta, im_idx=idx)


if __name__ == "__main__":
    main()

答案 6 :(得分:2)

我创建了一个函数来绘制CIFAR10数据集中一行的RGB图像。由于图像的原始尺寸非常小(32px X 32px),因此图像最好是模糊的。

sample image

def unpickle(file):
    with open(file, 'rb') as fo:
        dict1 = pickle.load(fo, encoding='bytes')
    return dict1

pd_tr = pd.DataFrame()
tr_y = pd.DataFrame()

for i in range(1,6):
    data = unpickle('data/data_batch_' + str(i))
    pd_tr = pd_tr.append(pd.DataFrame(data[b'data']))
    tr_y = tr_y.append(pd.DataFrame(data[b'labels']))
    pd_tr['labels'] = tr_y

tr_x = np.asarray(pd_tr.iloc[:, :3072])
tr_y = np.asarray(pd_tr['labels'])
ts_x = np.asarray(unpickle('data/test_batch')[b'data'])
ts_y = np.asarray(unpickle('data/test_batch')[b'labels'])    
labels = unpickle('data/batches.meta')[b'label_names']

def plot_CIFAR(ind):
    arr = tr_x[ind]
    sc_dpi = 157.35
    R = arr[0:1024].reshape(32,32)/255.0
    G = arr[1024:2048].reshape(32,32)/255.0
    B = arr[2048:].reshape(32,32)/255.0

    img = np.dstack((R,G,B))
    title = re.sub('[!@#$b]', '', str(labels[tr_y[ind]]))
    fig = plt.figure(figsize=(3,3))
    ax = fig.add_subplot(111)
    ax.imshow(img,interpolation='bicubic')
    ax.set_title('Category = '+ title,fontsize =15)

plot_CIFAR(4)

答案 7 :(得分:0)

添加0.5:

plt.imshow(np.transpose(img, (1, 2, 0)) + 0.5)

答案 8 :(得分:0)

代码结果是:尝试下面的代码。 enter image description here

我发现了有关mnist和cifar图像可视化的非常有用的链接。您可以找到各种图像的代码: https://machinelearningmastery.com/how-to-load-and-visualize-standard-computer-vision-datasets-with-keras/ cifar10图片代码如下: 它运作良好。图片在上方。

[, V1 := NULL]