如何将MNIST数据集转换为3通道?

时间:2019-04-20 21:57:22

标签: python tensorflow reshape mnist

我正在研究一个转移学习项目,并试图用MNIST数据集替换原始的Flower数据集以检测数字。我遵循的程序是:Original Program第9部分。任务是使用其他数据集,并查看准确性如何。在执行9.4节中的代码时遇到错误:

from random import sample

def prepare_batch(flower_paths_and_classes, batch_size):
    batch_paths_and_classes = sample(flower_paths_and_classes, batch_size)
  >>>  images = [mpimg.imread(path)[:, :, :channels] for path, labels in
              prepared_images = [prepare_image(image) for image in images]
    X_batch = 2 * np.stack(prepared_images) - 1 # Inception expects colors ranging from -1 to 1
    y_batch = np.array([labels for path, labels in batch_paths_and_classes], dtype=np.int32)
    return X_batch, y_batch 

此行中的错误是由于mnist图像的形状。当我按原样运行程序时,我得到一个

IndexError: too many indices for array

我是python / tensorflow的初学者,想知道如何将数据以正确的格式输入到预训练模型中。 如果我从上方的行中删除:channels,则不会产生错误,但是X_batch会输出(4,299,299),而输出则需要为(?,299,299,3)。由于MNIST是灰度的,因此我需要使用cv2函数,例如:

#Note the rbg it's how opencv interprets a rgb image
img = cv2.merge((r,b,g))

#saves the merged image to a file     
cv2.imwrite("rgb.jpg",img)

还是会使用重塑功能? 任何提示都很棒!

0 个答案:

没有答案