ValueError when trying to use the MNIST / EMNIST letters dataset in a machine learning program

Time: 2019-07-16 20:01:43

Tags: python machine-learning keras mnist

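I have a machine learning program that I am trying to train to recognize handwritten characters, using the EMNIST (Extended MNIST) dataset through a wrapper package on PyPI. I keep getting an error message.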

I have tried changing the tuple that describes the shape of the dataset, but without success.

import os                                       # For modifying files, making directories, etc.
import keras                                    # Simplified TensorFlow library
import matplotlib.pyplot as plt                 # Image display from PyPlot
import numpy as np                              # NumPy for advanced math
import random                                   # For picking a random value within the MNIST data set
from keras.models import Sequential             # Enables sequential layers for AI learning
from keras.layers import Dense                  # Enables Dense layers
from keras.layers import Conv2D                 # Enables 2D Convoluted Neural Network layers
from keras.layers import MaxPooling2D           # Enables Maximum Pooling layers
from keras.layers import Dropout                # Enables Dropout layers
from keras.layers import Flatten                # Enables Flatten layers
from keras.datasets import mnist                # Import MNIST dataset
from emnist import extract_training_samples
from keras.callbacks import ModelCheckpoint     # Imports ModelCheckpoint for saving progress

images, labels = extract_training_samples('letters')

images = images.reshape(-1, 28, 28, 1)
print(images)
labels = keras.utils.to_categorical(labels, num_classes=27)

images = images.astype('float32')

images /= 255

labels = keras.utils.to_categorical(labels, num_classes=27)

model = Sequential()                                                            # Use sequential layers for AI training
model.add(Conv2D(32, (3, 3), activation='relu', input_shape = (28, 28, 1)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size = (2, 2)))
model.add(Flatten())
model.add(Dense(128, activation = 'relu'))
model.add(Dropout(rate = 0.3))
model.add(Dense(27, activation = 'softmax'))
model.summary()
if os.path.exists("./saves/main_EMNIST_v2.hdf5"):
    if input("Load weights from previous model fit? (Y/N)\t").lower() == "y":
        load = True
    else:
        load = False
else:
    load = False
    if not os.path.exists("./saves/"):
        os.makedirs("saves")
if load:
    print("Loading weights from previous model fit...\t", end="")
    model.load_weights("./saves/main_EMNIST_v2.hdf5")
    print("Done")
if input("Save weights from best model fit? (Y/N)\t").lower() == "y":
    save = True;
else:
    save = False;

images = images.reshape(images.shape[0], 1, 28, 28)

"""
labels = keras.utils.to_categorical(labels, 6)
y_test = keras.utils.to_categorical(y_test, 6)
"""

plt.imshow(images[54000][0], cmap='gray')
plt.show()

model.compile(loss = keras.losses.categorical_crossentropy, optimizer=keras.optimizers.Adadelta(), metrics=['accuracy'])

callbacks_list = []
if save:
    filepath="./saves/main_EMNIST_v2.hdf5"
    checkpoint = ModelCheckpoint(filepath, monitor='acc', verbose=1, save_best_only=True, mode='max')
    callbacks_list = [checkpoint]

history = model.fit(images, labels, batch_size=64, epochs=int(input("Input epochs number. Recommended value is 5.\t"), 10), verbose=1, callbacks=callbacks_list)

# Code after this point has been omitted for the sake of brevity.

I get this error message:

Traceback (most recent call last):
  File "OCIR_EMNIST_v2.py", line 74, in <module>
    history = model.fit(images, labels, batch_size=64, epochs=int(input("Input epochs number. Recommended value is 5.\t"), 10), verbose=1, callbacks=callbacks_list)
  File "/home/user/.local/lib/python3.7/site-packages/keras/engine/training.py", line 952, in fit
    batch_size=batch_size)
  File "/home/user/.local/lib/python3.7/site-packages/keras/engine/training.py", line 751, in _standardize_user_data
    exception_prefix='input')
  File "/home/user/.local/lib/python3.7/site-packages/keras/engine/training_utils.py", line 138, in standardize_input_data
    str(data_shape))
ValueError: Error when checking input: expected conv2d_1_input to have shape (28, 28, 1) but got array with shape (1, 28, 28)

1 Answer:

Answer 0 (score: 0)

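The error is quite explicit: the model expects images of shape (28, 28, 1), but you are feeding it images of shape (1, 28, 28). Look at these two lines in your code: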

images = images.reshape(-1, 28, 28, 1)
images = images.reshape(images.shape[0], 1, 28, 28)

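The first line is correct: it reshapes each image to (28, 28, 1), which matches the input_shape given to the first Conv2D layer. The second line undoes this by moving the channel axis to the front, and that is what triggers the error. Simply remove the line images = images.reshape(images.shape[0], 1, 28, 28) from your code.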
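The shape mismatch can be reproduced with NumPy alone. The following is a minimal sketch, assuming a small zero-filled array as a stand-in for the EMNIST letters images (the real data comes from extract_training_samples('letters')); the names wrong and sample are only for illustration. It also shows one way to pick an image for plt.imshow once the second reshape is gone, since images[54000][0] in the question relies on the channels-first layout.

```python
import numpy as np

# Stand-in for the EMNIST letters images: 10 blank 28x28 grayscale images.
# (Hypothetical data; the real array comes from extract_training_samples.)
images = np.zeros((10, 28, 28), dtype=np.uint8)

# Channels-last layout, matching input_shape=(28, 28, 1) in the model.
images = images.reshape(-1, 28, 28, 1).astype("float32") / 255
print(images.shape)  # (10, 28, 28, 1)

# The second reshape from the question moves the channel axis to the front,
# so each sample becomes (1, 28, 28) and no longer matches the first Conv2D.
wrong = images.reshape(images.shape[0], 1, 28, 28)
print(wrong.shape)  # (10, 1, 28, 28)

# With the channels-last array, select channel 0 of one image for plotting,
# e.g. plt.imshow(sample, cmap='gray').
sample = images[5, :, :, 0]
print(sample.shape)  # (28, 28)
```

Separately, the question's code appears to call keras.utils.to_categorical on labels twice. One-hot encoding an array that is already one-hot would probably give the labels an extra dimension, so it may be worth keeping only one of those calls.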