我正在研究一个转移学习项目,并试图用MNIST数据集替换原始的Flower数据集以检测数字。我遵循的程序是:Original Program第9部分。任务是使用其他数据集,并查看准确性如何。在执行9.4节中的代码时遇到错误:
from random import sample
def prepare_batch(flower_paths_and_classes, batch_size):
batch_paths_and_classes = sample(flower_paths_and_classes, batch_size)
>>> images = [mpimg.imread(path)[:, :, :channels] for path, labels in
prepared_images = [prepare_image(image) for image in images]
X_batch = 2 * np.stack(prepared_images) - 1 # Inception expects colors ranging from -1 to 1
y_batch = np.array([labels for path, labels in batch_paths_and_classes], dtype=np.int32)
return X_batch, y_batch
此行中的错误是由于mnist图像的形状。当我按原样运行程序时,我得到一个
IndexError: too many indices for array
我是python / tensorflow的初学者,想知道如何将数据以正确的格式输入到预训练模型中。
如果我从上方的行中删除:channels
,则不会产生错误,但是X_batch会输出(4,299,299),而输出则需要为(?,299,299,3)。由于MNIST是灰度的,因此我需要使用cv2函数,例如:
#Note the rbg it's how opencv interprets a rgb image
img = cv2.merge((r,b,g))
#saves the merged image to a file
cv2.imwrite("rgb.jpg",img)
还是会使用重塑功能? 任何提示都很棒!