如何生成Cifar-10的清晰图像

时间:2017-03-02 03:44:40

标签: python python-2.7 numpy matplotlib tensorflow

我正在使用tensorflow并尝试在Cifar-10上可视化自动编码器的输入/输出。

我在这里回答这个问题:Why CIFAR-10 images are not displayed properly using matplotlib?

这是通过稍微修改运行代码的结果(将figsize更改为5,5):

Visualized Images

然而,这仍然不如原始页面中的图像清晰明确:https://www.cs.toronto.edu/~kriz/cifar.html

我怎样才能做得更好?

2 个答案:

答案 0 :(得分:0)

这里可能有两个问题:

问题1:

看起来您的颜色通道(红色,绿色,蓝色)是混合的。这可以解释为什么颜色如此奇怪。如果是这种情况,您将需要交换阵列中的颜色通道,如下所示。

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.cbook import get_sample_data

rgb_image = plt.imread(get_sample_data("grace_hopper.png", asfileobj=False))

# correct color channels (R, G, B)
plt.figure()
plt.imshow(rgb_image)
plt.axis('off')

natural color

# swapped color channels (R, B, G)
rgb_image = rgb_image[:, :, [0, 2, 1]]
plt.figure()
plt.imshow(rgb_image)
plt.axis('off')

swapped color channels

问题2:

Matplotlib的plt.imshow有一个关键字参数interpolation,如果没有指定,默认为None。 Matplotlib然后引用您的本地样式表来确定默认的插值行为。根据您的样式表,这可能会导致应用插值,从而导致图像失真。请参阅documentation for imshow for more details

如果您想保证Matplotlib不会插入图片,您应在interpolation="none"中指定plt.imshow。这很令人困惑,因为默认的NoneType值None产生的行为与"none"的字符串值不同。

red = np.zeros((100, 100, 3), dtype=np.uint8)
red[:, :, 0] = 255
red[40:60, 40:60, :] = 255

# with interpolation
plt.figure()
plt.imshow(red, interpolation='bicubic') 
plt.axis('off')

with interpolation

# without interpolation
plt.figure()
plt.imshow(red, interpolation='none') 
plt.axis('off')

without interpolation

答案 1 :(得分:0)

也许你应该这样做。图像非常小,高度和宽度均为32像素,因此只有在缩略图大小时它们才会更清晰。我在这里用双三次变换对它进行插值。但你可以把它改成“没有”。所以,不是模糊,你会得到一个像素化的图像。

def unpickle(file):
    with open(file, 'rb') as fo:
        dict1 = pickle.load(fo, encoding='bytes')
    return dict1

pd_tr = pd.DataFrame()
tr_y = pd.DataFrame()

for i in range(1,6):
    data = unpickle('data/data_batch_' + str(i))
    pd_tr = pd_tr.append(pd.DataFrame(data[b'data']))
    tr_y = tr_y.append(pd.DataFrame(data[b'labels']))
    pd_tr['labels'] = tr_y

tr_x = np.asarray(pd_tr.iloc[:, :3072])
tr_y = np.asarray(pd_tr['labels'])
ts_x = np.asarray(unpickle('data/test_batch')[b'data'])
ts_y = np.asarray(unpickle('data/test_batch')[b'labels'])    
labels = unpickle('data/batches.meta')[b'label_names']

def plot_CIFAR(ind):
    arr = tr_x[ind]
    R = arr[0:1024].reshape(32,32)/255.0
    G = arr[1024:2048].reshape(32,32)/255.0
    B = arr[2048:].reshape(32,32)/255.0

    img = np.dstack((R,G,B))
    title = re.sub('[!@#$b]', '', str(labels[tr_y[ind]]))
    fig = plt.figure(figsize=(3,3))
    ax = fig.add_subplot(111)
    ax.imshow(img,interpolation='bicubic')
    ax.set_title('Category = '+ title,fontsize =15)

plot_CIFAR(4)