torchvision MNIST装载机是否工作不正常或我做错了什么?

时间:2017-10-15 15:29:09

标签: python mnist pytorch

我想看到我在网络中使用的图像没问题,所以我使用以下代码保存了一堆图像:

<?php include("code.text"); ?>

这不是最干净的代码,但只是保存了一堆标记为'8'的图像。在打开它们时,我发现它们中的大多数看起来像this,即使它们中的一小部分完全是fine

我做错了吗?

1 个答案:

答案 0 :(得分:0)

来自评论:

问题出现在这一行cur_img.reshape((28, 28)).astype('uint8') * 255中,在将其乘以255之前将规范化图像转换为整数,从而产生0或255的图像。

更新的代码:

train_set = dset.MNIST(root=root, train=True, transform=transforms.ToTensor(), download=download)

for it, (img, target) in enumerate(train_loader):
    X = Variable(img)
    tar = Variable(target)
    X = X.view(batch_size, -1)
    cur_img_batch = X.data.numpy()
    cur_tar_batch = tar.data.numpy()
    for i in range(batch_size):
        cur_img = cur_img_batch[i]
        im = Image.fromarray((cur_img.reshape((28, 28)) * 255).astype('uint8'))
        if cur_tar_batch[i] == 8:
            im.save(test_img_dir + 'iter_' + str(it) + '_sample_' + str(i) + '.png')